forked from mindspore-Ecosystem/mindspore
!22040 optimize document of probability
Merge pull request !22040 from byweng/master
This commit is contained in:
commit
9db8f3221f
|
@ -147,6 +147,7 @@ class Bijector(Cell):
|
||||||
return (shape_tensor + dist_shape_tensor).shape
|
return (shape_tensor + dist_shape_tensor).shape
|
||||||
|
|
||||||
def shape_mapping(self, shape):
|
def shape_mapping(self, shape):
|
||||||
|
"""Map shape."""
|
||||||
return self._shape_mapping(shape)
|
return self._shape_mapping(shape)
|
||||||
|
|
||||||
def _add_parameter(self, value, name):
|
def _add_parameter(self, value, name):
|
||||||
|
@ -161,7 +162,7 @@ class Bijector(Cell):
|
||||||
self.common_dtype = None
|
self.common_dtype = None
|
||||||
# cast value to a tensor if it is not None
|
# cast value to a tensor if it is not None
|
||||||
if isinstance(value, bool) or value is None:
|
if isinstance(value, bool) or value is None:
|
||||||
raise TypeError(f"{name} cannot be type {type(value)}")
|
raise TypeError("{} cannot be type {}".format(name, type(value)))
|
||||||
value_t = Tensor(value)
|
value_t = Tensor(value)
|
||||||
# if the bijector's dtype is not specified
|
# if the bijector's dtype is not specified
|
||||||
if self.dtype is None:
|
if self.dtype is None:
|
||||||
|
|
|
@ -57,8 +57,9 @@ class Exp(PowerTransform):
|
||||||
super(Exp, self).__init__(name=name)
|
super(Exp, self).__init__(name=name)
|
||||||
|
|
||||||
def extend_repr(self):
|
def extend_repr(self):
|
||||||
|
"""Display instance object as string."""
|
||||||
if self.is_scalar_batch:
|
if self.is_scalar_batch:
|
||||||
str_info = 'exp'
|
str_info = 'exp'
|
||||||
else:
|
else:
|
||||||
str_info = f'batch_shape = {self.batch_shape}'
|
str_info = 'batch_shape = {}'.format(self.batch_shape)
|
||||||
return str_info
|
return str_info
|
||||||
|
|
|
@ -28,7 +28,7 @@ class GumbelCDF(Bijector):
|
||||||
Y = \exp(-\exp(\frac{-(X - loc)}{scale}))
|
Y = \exp(-\exp(\frac{-(X - loc)}{scale}))
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
loc (float, list, numpy.ndarray, Tensor): The location. Default: 0..
|
loc (float, list, numpy.ndarray, Tensor): The location. Default: 0.0.
|
||||||
scale (float, list, numpy.ndarray, Tensor): The scale. Default: 1.0.
|
scale (float, list, numpy.ndarray, Tensor): The scale. Default: 1.0.
|
||||||
name (str): The name of the Bijector. Default: 'GumbelCDF'.
|
name (str): The name of the Bijector. Default: 'GumbelCDF'.
|
||||||
|
|
||||||
|
@ -101,10 +101,11 @@ class GumbelCDF(Bijector):
|
||||||
return self._scale
|
return self._scale
|
||||||
|
|
||||||
def extend_repr(self):
|
def extend_repr(self):
|
||||||
|
"""Display instance object as string."""
|
||||||
if self.is_scalar_batch:
|
if self.is_scalar_batch:
|
||||||
str_info = f'loc = {self.loc}, scale = {self.scale}'
|
str_info = 'loc = {}, scale = {}'.format(self.loc, self.scale)
|
||||||
else:
|
else:
|
||||||
str_info = f'batch_shape = {self.batch_shape}'
|
str_info = 'batch_shape = {}'.format(self.batch_shape)
|
||||||
return str_info
|
return str_info
|
||||||
|
|
||||||
def _forward(self, x):
|
def _forward(self, x):
|
||||||
|
@ -112,9 +113,12 @@ class GumbelCDF(Bijector):
|
||||||
loc_local = self.cast_param_by_value(x, self.loc)
|
loc_local = self.cast_param_by_value(x, self.loc)
|
||||||
scale_local = self.cast_param_by_value(x, self.scale)
|
scale_local = self.cast_param_by_value(x, self.scale)
|
||||||
z = (x - loc_local) / scale_local
|
z = (x - loc_local) / scale_local
|
||||||
|
# pylint: disable=E1130
|
||||||
return self.exp(-self.exp(-z))
|
return self.exp(-self.exp(-z))
|
||||||
|
|
||||||
def _inverse(self, y):
|
def _inverse(self, y):
|
||||||
|
# pylint false positive
|
||||||
|
# pylint: disable=E1130
|
||||||
y = self._check_value_dtype(y)
|
y = self._check_value_dtype(y)
|
||||||
loc_local = self.cast_param_by_value(y, self.loc)
|
loc_local = self.cast_param_by_value(y, self.loc)
|
||||||
scale_local = self.cast_param_by_value(y, self.scale)
|
scale_local = self.cast_param_by_value(y, self.scale)
|
||||||
|
|
|
@ -23,7 +23,8 @@ class Invert(Bijector):
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
bijector (Bijector): Base Bijector.
|
bijector (Bijector): Base Bijector.
|
||||||
name (str): The name of the Bijector. Default: 'Invert' + bijector.name.
|
name (str): The name of the Bijector. Default: "". When name is set to "", it is actually
|
||||||
|
'Invert' + bijector.name.
|
||||||
|
|
||||||
Supported Platforms:
|
Supported Platforms:
|
||||||
``Ascend`` ``GPU``
|
``Ascend`` ``GPU``
|
||||||
|
@ -67,16 +68,29 @@ class Invert(Bijector):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def bijector(self):
|
def bijector(self):
|
||||||
|
"""Return base bijector."""
|
||||||
return self._bijector
|
return self._bijector
|
||||||
|
|
||||||
def inverse(self, y):
|
def inverse(self, y):
|
||||||
|
"""
|
||||||
|
Forward transformation: transform the input value to another distribution.
|
||||||
|
"""
|
||||||
return self.bijector("forward", y)
|
return self.bijector("forward", y)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
"""
|
||||||
|
Inverse transformation: transform the input value back to the original distribution.
|
||||||
|
"""
|
||||||
return self.bijector("inverse", x)
|
return self.bijector("inverse", x)
|
||||||
|
|
||||||
def inverse_log_jacobian(self, y):
|
def inverse_log_jacobian(self, y):
|
||||||
|
"""
|
||||||
|
Logarithm of the derivative of the forward transformation.
|
||||||
|
"""
|
||||||
return self.bijector("forward_log_jacobian", y)
|
return self.bijector("forward_log_jacobian", y)
|
||||||
|
|
||||||
def forward_log_jacobian(self, x):
|
def forward_log_jacobian(self, x):
|
||||||
|
"""
|
||||||
|
Logarithm of the derivative of the inverse transformation.
|
||||||
|
"""
|
||||||
return self.bijector("inverse_log_jacobian", x)
|
return self.bijector("inverse_log_jacobian", x)
|
||||||
|
|
|
@ -95,13 +95,13 @@ class PowerTransform(Bijector):
|
||||||
return self._power
|
return self._power
|
||||||
|
|
||||||
def extend_repr(self):
|
def extend_repr(self):
|
||||||
|
"""Display instance object as string."""
|
||||||
if self.is_scalar_batch:
|
if self.is_scalar_batch:
|
||||||
str_info = f'power = {self.power}'
|
str_info = 'power = {}'.format(self.power)
|
||||||
else:
|
else:
|
||||||
str_info = f'batch_shape = {self.batch_shape}'
|
str_info = 'batch_shape = {}'.format(self.batch_shape)
|
||||||
return str_info
|
return str_info
|
||||||
|
|
||||||
|
|
||||||
def _forward(self, x):
|
def _forward(self, x):
|
||||||
"""
|
"""
|
||||||
Evaluate the forward mapping.
|
Evaluate the forward mapping.
|
||||||
|
|
|
@ -101,10 +101,11 @@ class ScalarAffine(Bijector):
|
||||||
return self._shift
|
return self._shift
|
||||||
|
|
||||||
def extend_repr(self):
|
def extend_repr(self):
|
||||||
|
"""Display instance object as string."""
|
||||||
if self.is_scalar_batch:
|
if self.is_scalar_batch:
|
||||||
str_info = f'scale = {self.scale}, shift = {self.shift}'
|
str_info = 'scale = {}, shift = {}'.format(self.scale, self.shift)
|
||||||
else:
|
else:
|
||||||
str_info = f'batch_shape = {self.batch_shape}'
|
str_info = 'batch_shape = {}'.format(self.batch_shape)
|
||||||
return str_info
|
return str_info
|
||||||
|
|
||||||
def _forward(self, x):
|
def _forward(self, x):
|
||||||
|
|
|
@ -122,6 +122,7 @@ class Softplus(Bijector):
|
||||||
ones = self.fill(self.dtypeop(x), self.shape(x), 1.0)
|
ones = self.fill(self.dtypeop(x), self.shape(x), 1.0)
|
||||||
too_small_or_too_large = self.logicalor(too_small, too_large)
|
too_small_or_too_large = self.logicalor(too_small, too_large)
|
||||||
x = self.select(too_small_or_too_large, ones, x)
|
x = self.select(too_small_or_too_large, ones, x)
|
||||||
|
# pylint: disable=E1130
|
||||||
y = x + self.log(self.abs(self.expm1(-x)))
|
y = x + self.log(self.abs(self.expm1(-x)))
|
||||||
return self.select(too_small, too_small_value, self.select(too_large, too_large_value, y))
|
return self.select(too_small, too_small_value, self.select(too_large, too_large_value, y))
|
||||||
|
|
||||||
|
@ -130,10 +131,11 @@ class Softplus(Bijector):
|
||||||
return self._sharpness
|
return self._sharpness
|
||||||
|
|
||||||
def extend_repr(self):
|
def extend_repr(self):
|
||||||
|
"""Display instance object as string."""
|
||||||
if self.is_scalar_batch:
|
if self.is_scalar_batch:
|
||||||
str_info = f'sharpness = {self.sharpness}'
|
str_info = 'sharpness = {}'.format(self.sharpness)
|
||||||
else:
|
else:
|
||||||
str_info = f'batch_shape = {self.batch_shape}'
|
str_info = 'batch_shape = {}'.format(self.batch_shape)
|
||||||
return str_info
|
return str_info
|
||||||
|
|
||||||
def _forward(self, x):
|
def _forward(self, x):
|
||||||
|
|
|
@ -72,9 +72,7 @@ class _ConvVariational(_Conv):
|
||||||
self.group = group
|
self.group = group
|
||||||
self.has_bias = has_bias
|
self.has_bias = has_bias
|
||||||
|
|
||||||
# distribution trainable parameters
|
self.shape = [self.out_channels, self.in_channels // self.group, *self.kernel_size]
|
||||||
self.shape = [self.out_channels,
|
|
||||||
self.in_channels // self.group, *self.kernel_size]
|
|
||||||
|
|
||||||
self.weight.requires_grad = False
|
self.weight.requires_grad = False
|
||||||
self.weight_prior = check_prior(weight_prior_fn, "weight_prior_fn")
|
self.weight_prior = check_prior(weight_prior_fn, "weight_prior_fn")
|
||||||
|
@ -108,6 +106,7 @@ class _ConvVariational(_Conv):
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
def extend_repr(self):
|
def extend_repr(self):
|
||||||
|
"""Display instance object as string."""
|
||||||
s = 'in_channels={}, out_channels={}, kernel_size={}, stride={}, pad_mode={}, ' \
|
s = 'in_channels={}, out_channels={}, kernel_size={}, stride={}, pad_mode={}, ' \
|
||||||
'padding={}, dilation={}, group={}, weight_mean={}, weight_std={}, has_bias={}' \
|
'padding={}, dilation={}, group={}, weight_mean={}, weight_std={}, has_bias={}' \
|
||||||
.format(self.in_channels, self.out_channels, self.kernel_size, self.stride, self.pad_mode, self.padding,
|
.format(self.in_channels, self.out_channels, self.kernel_size, self.stride, self.pad_mode, self.padding,
|
||||||
|
@ -135,6 +134,7 @@ class _ConvVariational(_Conv):
|
||||||
return kl_loss
|
return kl_loss
|
||||||
|
|
||||||
def apply_variational_bias(self, inputs):
|
def apply_variational_bias(self, inputs):
|
||||||
|
"""Calculate bias."""
|
||||||
bias_posterior_tensor = self.bias_posterior("sample")
|
bias_posterior_tensor = self.bias_posterior("sample")
|
||||||
return self.bias_add(inputs, bias_posterior_tensor)
|
return self.bias_add(inputs, bias_posterior_tensor)
|
||||||
|
|
||||||
|
@ -261,6 +261,7 @@ class ConvReparam(_ConvVariational):
|
||||||
)
|
)
|
||||||
|
|
||||||
def apply_variational_weight(self, inputs):
|
def apply_variational_weight(self, inputs):
|
||||||
|
"""Calculate weight."""
|
||||||
weight_posterior_tensor = self.weight_posterior("sample")
|
weight_posterior_tensor = self.weight_posterior("sample")
|
||||||
outputs = self.conv2d(inputs, weight_posterior_tensor)
|
outputs = self.conv2d(inputs, weight_posterior_tensor)
|
||||||
return outputs
|
return outputs
|
||||||
|
|
|
@ -78,6 +78,7 @@ class _DenseVariational(Cell):
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
def extend_repr(self):
|
def extend_repr(self):
|
||||||
|
"""Display instance object as string."""
|
||||||
s = 'in_channels={}, out_channels={}, weight_mean={}, weight_std={}, has_bias={}' \
|
s = 'in_channels={}, out_channels={}, weight_mean={}, weight_std={}, has_bias={}' \
|
||||||
.format(self.in_channels, self.out_channels, self.weight_posterior.mean,
|
.format(self.in_channels, self.out_channels, self.weight_posterior.mean,
|
||||||
self.weight_posterior.untransformed_std, self.has_bias)
|
self.weight_posterior.untransformed_std, self.has_bias)
|
||||||
|
@ -89,6 +90,7 @@ class _DenseVariational(Cell):
|
||||||
return s
|
return s
|
||||||
|
|
||||||
def apply_variational_bias(self, inputs):
|
def apply_variational_bias(self, inputs):
|
||||||
|
"""Calculate bias."""
|
||||||
bias_posterior_tensor = self.bias_posterior("sample")
|
bias_posterior_tensor = self.bias_posterior("sample")
|
||||||
return self.bias_add(inputs, bias_posterior_tensor)
|
return self.bias_add(inputs, bias_posterior_tensor)
|
||||||
|
|
||||||
|
@ -196,6 +198,7 @@ class DenseReparam(_DenseVariational):
|
||||||
)
|
)
|
||||||
|
|
||||||
def apply_variational_weight(self, inputs):
|
def apply_variational_weight(self, inputs):
|
||||||
|
"""Calculate weight."""
|
||||||
weight_posterior_tensor = self.weight_posterior("sample")
|
weight_posterior_tensor = self.weight_posterior("sample")
|
||||||
outputs = self.matmul(inputs, weight_posterior_tensor)
|
outputs = self.matmul(inputs, weight_posterior_tensor)
|
||||||
return outputs
|
return outputs
|
||||||
|
@ -292,6 +295,7 @@ class DenseLocalReparam(_DenseVariational):
|
||||||
self.normal = Normal()
|
self.normal = Normal()
|
||||||
|
|
||||||
def apply_variational_weight(self, inputs):
|
def apply_variational_weight(self, inputs):
|
||||||
|
"""Calculate weight."""
|
||||||
mean = self.matmul(inputs, self.weight_posterior("mean"))
|
mean = self.matmul(inputs, self.weight_posterior("mean"))
|
||||||
std = self.sqrt(self.matmul(self.square(inputs), self.square(self.weight_posterior("sd"))))
|
std = self.sqrt(self.matmul(self.square(inputs), self.square(self.weight_posterior("sd"))))
|
||||||
weight_posterior_affine_tensor = self.normal("sample", mean=mean, sd=std)
|
weight_posterior_affine_tensor = self.normal("sample", mean=mean, sd=std)
|
||||||
|
|
|
@ -27,7 +27,7 @@ class Bernoulli(Distribution):
|
||||||
Bernoulli Distribution.
|
Bernoulli Distribution.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
probs (float, list, numpy.ndarray, Tensor): The probability of that the outcome is 1.
|
probs (float, list, numpy.ndarray, Tensor): The probability of that the outcome is 1. Default: None.
|
||||||
seed (int): The seed used in sampling. The global seed is used if it is None. Default: None.
|
seed (int): The seed used in sampling. The global seed is used if it is None. Default: None.
|
||||||
dtype (mindspore.dtype): The type of the event samples. Default: mstype.int32.
|
dtype (mindspore.dtype): The type of the event samples. Default: mstype.int32.
|
||||||
name (str): The name of the distribution. Default: 'Bernoulli'.
|
name (str): The name of the distribution. Default: 'Bernoulli'.
|
||||||
|
@ -153,10 +153,11 @@ class Bernoulli(Distribution):
|
||||||
self.uniform = C.uniform
|
self.uniform = C.uniform
|
||||||
|
|
||||||
def extend_repr(self):
|
def extend_repr(self):
|
||||||
|
"""Display instance object as string."""
|
||||||
if self.is_scalar_batch:
|
if self.is_scalar_batch:
|
||||||
s = f'probs = {self.probs}'
|
s = 'probs = {}'.format(self.probs)
|
||||||
else:
|
else:
|
||||||
s = f'batch_shape = {self._broadcast_shape}'
|
s = 'batch_shape = {}'.format(self._broadcast_shape)
|
||||||
return s
|
return s
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|
|
@ -181,10 +181,11 @@ class Beta(Distribution):
|
||||||
self.lbeta = nn.LBeta()
|
self.lbeta = nn.LBeta()
|
||||||
|
|
||||||
def extend_repr(self):
|
def extend_repr(self):
|
||||||
|
"""Display instance object as string."""
|
||||||
if self.is_scalar_batch:
|
if self.is_scalar_batch:
|
||||||
s = f'concentration1 = {self._concentration1}, concentration0 = {self._concentration0}'
|
s = 'concentration1 = {}, concentration0 = {}'.format(self._concentration1, self._concentration0)
|
||||||
else:
|
else:
|
||||||
s = f'batch_shape = {self._broadcast_shape}'
|
s = 'batch_shape = {}'.format(self._broadcast_shape)
|
||||||
return s
|
return s
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|
|
@ -171,10 +171,11 @@ class Categorical(Distribution):
|
||||||
return self._probs
|
return self._probs
|
||||||
|
|
||||||
def extend_repr(self):
|
def extend_repr(self):
|
||||||
|
"""Display instance object as string."""
|
||||||
if self.is_scalar_batch:
|
if self.is_scalar_batch:
|
||||||
s = f'probs = {self.probs}'
|
s = 'probs = {}'.format(self.probs)
|
||||||
else:
|
else:
|
||||||
s = f'batch_shape = {self._broadcast_shape}'
|
s = 'batch_shape = {}'.format(self._broadcast_shape)
|
||||||
return s
|
return s
|
||||||
|
|
||||||
def _get_dist_type(self):
|
def _get_dist_type(self):
|
||||||
|
|
|
@ -173,10 +173,11 @@ class Cauchy(Distribution):
|
||||||
|
|
||||||
|
|
||||||
def extend_repr(self):
|
def extend_repr(self):
|
||||||
|
"""Display instance object as string."""
|
||||||
if self.is_scalar_batch:
|
if self.is_scalar_batch:
|
||||||
str_info = f'location = {self._loc}, scale = {self._scale}'
|
str_info = 'location = {}, scale = {}'.format(self._loc, self._scale)
|
||||||
else:
|
else:
|
||||||
str_info = f'batch_shape = {self._broadcast_shape}'
|
str_info = 'batch_shape = {}'.format(self._broadcast_shape)
|
||||||
return str_info
|
return str_info
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -249,6 +250,7 @@ class Cauchy(Distribution):
|
||||||
value = self.cast(value, self.dtype)
|
value = self.cast(value, self.dtype)
|
||||||
loc, scale = self._check_param_type(loc, scale)
|
loc, scale = self._check_param_type(loc, scale)
|
||||||
z = (value - loc) / scale
|
z = (value - loc) / scale
|
||||||
|
# pylint: disable=E1130
|
||||||
log_unnormalized_prob = - self.log1p(self.sq(z))
|
log_unnormalized_prob = - self.log1p(self.sq(z))
|
||||||
log_normalization = self.log(np.pi * scale)
|
log_normalization = self.log(np.pi * scale)
|
||||||
return log_unnormalized_prob - log_normalization
|
return log_unnormalized_prob - log_normalization
|
||||||
|
|
|
@ -28,7 +28,7 @@ class Exponential(Distribution):
|
||||||
Example class: Exponential Distribution.
|
Example class: Exponential Distribution.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
rate (float, list, numpy.ndarray, Tensor): The inverse scale.
|
rate (float, list, numpy.ndarray, Tensor): The inverse scale. Default: None.
|
||||||
seed (int): The seed used in sampling. The global seed is used if it is None. Default: None.
|
seed (int): The seed used in sampling. The global seed is used if it is None. Default: None.
|
||||||
dtype (mindspore.dtype): The type of the event samples. Default: mstype.float32.
|
dtype (mindspore.dtype): The type of the event samples. Default: mstype.float32.
|
||||||
name (str): The name of the distribution. Default: 'Exponential'.
|
name (str): The name of the distribution. Default: 'Exponential'.
|
||||||
|
@ -156,10 +156,11 @@ class Exponential(Distribution):
|
||||||
self.uniform = C.uniform
|
self.uniform = C.uniform
|
||||||
|
|
||||||
def extend_repr(self):
|
def extend_repr(self):
|
||||||
|
"""Display instance object as string."""
|
||||||
if self.is_scalar_batch:
|
if self.is_scalar_batch:
|
||||||
s = f'rate = {self.rate}'
|
s = 'rate = {}'.format(self.rate)
|
||||||
else:
|
else:
|
||||||
s = f'batch_shape = {self._broadcast_shape}'
|
s = 'batch_shape = {}'.format(self._broadcast_shape)
|
||||||
return s
|
return s
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|
|
@ -180,10 +180,11 @@ class Gamma(Distribution):
|
||||||
self.igamma = nn.IGamma()
|
self.igamma = nn.IGamma()
|
||||||
|
|
||||||
def extend_repr(self):
|
def extend_repr(self):
|
||||||
|
"""Display instance object as string."""
|
||||||
if self.is_scalar_batch:
|
if self.is_scalar_batch:
|
||||||
s = f'concentration = {self._concentration}, rate = {self._rate}'
|
s = 'concentration = {}, rate = {}'.format(self._concentration, self._rate)
|
||||||
else:
|
else:
|
||||||
s = f'batch_shape = {self._broadcast_shape}'
|
s = 'batch_shape = {}'.format(self._broadcast_shape)
|
||||||
return s
|
return s
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|
|
@ -165,10 +165,11 @@ class Geometric(Distribution):
|
||||||
self.uniform = C.uniform
|
self.uniform = C.uniform
|
||||||
|
|
||||||
def extend_repr(self):
|
def extend_repr(self):
|
||||||
|
"""Display instance object as string."""
|
||||||
if not self.is_scalar_batch:
|
if not self.is_scalar_batch:
|
||||||
s = f'batch_shape = {self._broadcast_shape}'
|
s = 'batch_shape = {}'.format(self._broadcast_shape)
|
||||||
else:
|
else:
|
||||||
s = f'probs = {self.probs}'
|
s = 'probs = {}'.format(self.probs)
|
||||||
return s
|
return s
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|
|
@ -112,10 +112,11 @@ class Gumbel(TransformedDistribution):
|
||||||
return self._scale
|
return self._scale
|
||||||
|
|
||||||
def extend_repr(self):
|
def extend_repr(self):
|
||||||
|
"""Display instance object as string."""
|
||||||
if self.is_scalar_batch:
|
if self.is_scalar_batch:
|
||||||
str_info = f'loc = {self._loc}, scale = {self._scale}'
|
str_info = 'loc = {}, scale = {}'.format(self._loc, self._scale)
|
||||||
else:
|
else:
|
||||||
str_info = f'batch_shape = {self._broadcast_shape}'
|
str_info = 'batch_shape = {}'.format(self._broadcast_shape)
|
||||||
return str_info
|
return str_info
|
||||||
|
|
||||||
def _get_dist_type(self):
|
def _get_dist_type(self):
|
||||||
|
|
|
@ -129,10 +129,11 @@ class LogNormal(msd.TransformedDistribution):
|
||||||
return loc, scale
|
return loc, scale
|
||||||
|
|
||||||
def extend_repr(self):
|
def extend_repr(self):
|
||||||
|
"""Display instance object as string."""
|
||||||
if self.is_scalar_batch:
|
if self.is_scalar_batch:
|
||||||
s = f'loc = {self.loc}, scale = {self.scale}'
|
s = 'loc = {}, scale = {}'.format(self.loc, self.scale)
|
||||||
else:
|
else:
|
||||||
s = f'batch_shape = {self.broadcast_shape}'
|
s = 'batch_shape = {}'.format(self.broadcast_shape)
|
||||||
return s
|
return s
|
||||||
|
|
||||||
def _mean(self, loc=None, scale=None):
|
def _mean(self, loc=None, scale=None):
|
||||||
|
|
|
@ -173,10 +173,11 @@ class Logistic(Distribution):
|
||||||
return self.select(too_small, too_small_value, self.select(too_large, too_large_value, y))
|
return self.select(too_small, too_small_value, self.select(too_large, too_large_value, y))
|
||||||
|
|
||||||
def extend_repr(self):
|
def extend_repr(self):
|
||||||
|
"""Display instance object as string."""
|
||||||
if self.is_scalar_batch:
|
if self.is_scalar_batch:
|
||||||
s = f'location = {self._loc}, scale = {self._scale}'
|
s = 'location = {}, scale = {}'.format(self._loc, self._scale)
|
||||||
else:
|
else:
|
||||||
s = f'batch_shape = {self._broadcast_shape}'
|
s = 'batch_shape = {}'.format(self._broadcast_shape)
|
||||||
return s
|
return s
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -291,6 +292,7 @@ class Logistic(Distribution):
|
||||||
value = self.cast(value, self.dtype)
|
value = self.cast(value, self.dtype)
|
||||||
loc, scale = self._check_param_type(loc, scale)
|
loc, scale = self._check_param_type(loc, scale)
|
||||||
z = (value - loc) / scale
|
z = (value - loc) / scale
|
||||||
|
# pylint: disable=E1130
|
||||||
return -self.softplus(-z)
|
return -self.softplus(-z)
|
||||||
|
|
||||||
def _survival_function(self, value, loc=None, scale=None):
|
def _survival_function(self, value, loc=None, scale=None):
|
||||||
|
@ -327,6 +329,7 @@ class Logistic(Distribution):
|
||||||
value = self.cast(value, self.dtype)
|
value = self.cast(value, self.dtype)
|
||||||
loc, scale = self._check_param_type(loc, scale)
|
loc, scale = self._check_param_type(loc, scale)
|
||||||
z = (value - loc) / scale
|
z = (value - loc) / scale
|
||||||
|
# pylint: disable=E1130
|
||||||
return -self.softplus(z)
|
return -self.softplus(z)
|
||||||
|
|
||||||
def _sample(self, shape=(), loc=None, scale=None):
|
def _sample(self, shape=(), loc=None, scale=None):
|
||||||
|
|
|
@ -164,10 +164,11 @@ class Normal(Distribution):
|
||||||
self.sqrt = P.Sqrt()
|
self.sqrt = P.Sqrt()
|
||||||
|
|
||||||
def extend_repr(self):
|
def extend_repr(self):
|
||||||
|
"""Display instance object as string."""
|
||||||
if self.is_scalar_batch:
|
if self.is_scalar_batch:
|
||||||
s = f'mean = {self._mean_value}, standard deviation = {self._sd_value}'
|
s = 'mean = {}, standard deviation = {}'.format(self._mean_value, self._sd_value)
|
||||||
else:
|
else:
|
||||||
s = f'batch_shape = {self._broadcast_shape}'
|
s = 'batch_shape = {}'.format(self._broadcast_shape)
|
||||||
return s
|
return s
|
||||||
|
|
||||||
def _get_dist_type(self):
|
def _get_dist_type(self):
|
||||||
|
|
|
@ -155,10 +155,11 @@ class Poisson(Distribution):
|
||||||
return self._rate
|
return self._rate
|
||||||
|
|
||||||
def extend_repr(self):
|
def extend_repr(self):
|
||||||
|
"""Display instance object as string."""
|
||||||
if self.is_scalar_batch:
|
if self.is_scalar_batch:
|
||||||
s = f'rate = {self.rate}'
|
s = 'rate = {}'.format(self.rate)
|
||||||
else:
|
else:
|
||||||
s = f'batch_shape = {self._broadcast_shape}'
|
s = 'batch_shape = {}'.format(self._broadcast_shape)
|
||||||
return s
|
return s
|
||||||
|
|
||||||
def _get_dist_type(self):
|
def _get_dist_type(self):
|
||||||
|
@ -219,6 +220,7 @@ class Poisson(Distribution):
|
||||||
safe_x = self.select(self.less(value, zeros), zeros, value)
|
safe_x = self.select(self.less(value, zeros), zeros, value)
|
||||||
y = log_rate * safe_x - self.lgamma(safe_x + 1.)
|
y = log_rate * safe_x - self.lgamma(safe_x + 1.)
|
||||||
comp = self.equal(value, safe_x)
|
comp = self.equal(value, safe_x)
|
||||||
|
# pylint: disable=E1130
|
||||||
log_unnormalized_prob = self.select(comp, y, -inf)
|
log_unnormalized_prob = self.select(comp, y, -inf)
|
||||||
log_normalization = self.exp(log_rate)
|
log_normalization = self.exp(log_rate)
|
||||||
return log_unnormalized_prob - log_normalization
|
return log_unnormalized_prob - log_normalization
|
||||||
|
|
|
@ -170,10 +170,11 @@ class Uniform(Distribution):
|
||||||
self.uniform = C.uniform
|
self.uniform = C.uniform
|
||||||
|
|
||||||
def extend_repr(self):
|
def extend_repr(self):
|
||||||
|
"""Display instance object as string."""
|
||||||
if self.is_scalar_batch:
|
if self.is_scalar_batch:
|
||||||
s = f'low = {self.low}, high = {self.high}'
|
s = 'low = {}, high = {}'.format(self.low, self.high)
|
||||||
else:
|
else:
|
||||||
s = f'batch_shape = {self._broadcast_shape}'
|
s = 'batch_shape = {}'.format(self._broadcast_shape)
|
||||||
return s
|
return s
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|
Loading…
Reference in New Issue