forked from mindspore-Ecosystem/mindspore
!22040 optimize document of probability
Merge pull request !22040 from byweng/master
commit 9db8f3221f
@@ -147,6 +147,7 @@ class Bijector(Cell):
         return (shape_tensor + dist_shape_tensor).shape

     def shape_mapping(self, shape):
+        """Map shape."""
         return self._shape_mapping(shape)

     def _add_parameter(self, value, name):
@@ -161,7 +162,7 @@ class Bijector(Cell):
         self.common_dtype = None
         # cast value to a tensor if it is not None
         if isinstance(value, bool) or value is None:
-            raise TypeError(f"{name} cannot be type {type(value)}")
+            raise TypeError("{} cannot be type {}".format(name, type(value)))
         value_t = Tensor(value)
         # if the bijector's dtype is not specified
         if self.dtype is None:
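Note: throughout this patch, f-strings are rewritten as str.format calls. The commit message does not state the motivation (likely interpreter-compatibility or style consistency); the two spellings produce identical strings, as this standalone sketch with hypothetical placeholders shows:

    # Standalone sketch, not part of the patch: both spellings build the same message.
    name, value = "probs", True  # hypothetical placeholders
    msg_f = f"{name} cannot be type {type(value)}"
    msg_fmt = "{} cannot be type {}".format(name, type(value))
    assert msg_f == msg_fmt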
@@ -57,8 +57,9 @@ class Exp(PowerTransform):
         super(Exp, self).__init__(name=name)

     def extend_repr(self):
+        """Display instance object as string."""
         if self.is_scalar_batch:
             str_info = 'exp'
         else:
-            str_info = f'batch_shape = {self.batch_shape}'
+            str_info = 'batch_shape = {}'.format(self.batch_shape)
         return str_info
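Note: extend_repr feeds the summary that printing a Cell displays. A minimal usage sketch, assuming the usual import path mindspore.nn.probability.bijector:

    import mindspore.nn.probability.bijector as msb

    exp_bijector = msb.Exp()
    print(exp_bijector)  # the printed summary embeds extend_repr's string, here 'exp'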
@@ -28,7 +28,7 @@ class GumbelCDF(Bijector):
         Y = \exp(-\exp(\frac{-(X - loc)}{scale}))

     Args:
-        loc (float, list, numpy.ndarray, Tensor): The location. Default: 0..
+        loc (float, list, numpy.ndarray, Tensor): The location. Default: 0.0.
         scale (float, list, numpy.ndarray, Tensor): The scale. Default: 1.0.
         name (str): The name of the Bijector. Default: 'GumbelCDF'.
@@ -101,10 +101,11 @@ class GumbelCDF(Bijector):
         return self._scale

     def extend_repr(self):
+        """Display instance object as string."""
         if self.is_scalar_batch:
-            str_info = f'loc = {self.loc}, scale = {self.scale}'
+            str_info = 'loc = {}, scale = {}'.format(self.loc, self.scale)
         else:
-            str_info = f'batch_shape = {self.batch_shape}'
+            str_info = 'batch_shape = {}'.format(self.batch_shape)
         return str_info

     def _forward(self, x):
@@ -112,9 +113,12 @@ class GumbelCDF(Bijector):
         loc_local = self.cast_param_by_value(x, self.loc)
         scale_local = self.cast_param_by_value(x, self.scale)
         z = (x - loc_local) / scale_local
+        # pylint: disable=E1130
         return self.exp(-self.exp(-z))

     def _inverse(self, y):
+        # pylint false positive
+        # pylint: disable=E1130
         y = self._check_value_dtype(y)
         loc_local = self.cast_param_by_value(y, self.loc)
         scale_local = self.cast_param_by_value(y, self.scale)
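Note: the added `# pylint: disable=E1130` lines silence invalid-unary-operand-type, a false positive: pylint cannot see that unary `-` is overloaded for MindSpore Tensors. The math itself is the Gumbel CDF and its inverse; a NumPy sketch of the same formulas (illustrative, not the MindSpore code):

    import numpy as np

    def gumbel_cdf_forward(x, loc=0.0, scale=1.0):
        # y = exp(-exp(-(x - loc) / scale))
        z = (x - loc) / scale
        return np.exp(-np.exp(-z))

    def gumbel_cdf_inverse(y, loc=0.0, scale=1.0):
        # x = loc - scale * log(-log(y))
        return loc - scale * np.log(-np.log(y))

    x = np.array([0.5, 1.0, 2.0])
    np.testing.assert_allclose(gumbel_cdf_inverse(gumbel_cdf_forward(x)), x, rtol=1e-6)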
@@ -23,7 +23,8 @@ class Invert(Bijector):

     Args:
         bijector (Bijector): Base Bijector.
-        name (str): The name of the Bijector. Default: 'Invert' + bijector.name.
+        name (str): The name of the Bijector. Default: "". When name is set to "", it is actually
+            'Invert' + bijector.name.

     Supported Platforms:
         ``Ascend`` ``GPU``
@@ -67,16 +68,29 @@ class Invert(Bijector):

     @property
     def bijector(self):
+        """Return base bijector."""
         return self._bijector

     def inverse(self, y):
+        """
+        Forward transformation: transform the input value to another distribution.
+        """
         return self.bijector("forward", y)

     def forward(self, x):
+        """
+        Inverse transformation: transform the input value back to the original distribution.
+        """
         return self.bijector("inverse", x)

     def inverse_log_jacobian(self, y):
+        """
+        Logarithm of the derivative of the forward transformation.
+        """
         return self.bijector("forward_log_jacobian", y)

     def forward_log_jacobian(self, x):
+        """
+        Logarithm of the derivative of the inverse transformation.
+        """
         return self.bijector("inverse_log_jacobian", x)
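Note: Invert simply delegates each of its four methods to the opposite method of the base bijector, which the new docstrings now spell out. A hedged usage sketch (import path assumed):

    import mindspore.nn.probability.bijector as msb

    # Wrapping Exp yields a log-like transform:
    log_like = msb.Invert(msb.Exp())
    # log_like.forward(x) delegates to Exp's inverse (log),
    # log_like.inverse(y) delegates to Exp's forward (exp).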
@@ -95,13 +95,13 @@ class PowerTransform(Bijector):
         return self._power

     def extend_repr(self):
+        """Display instance object as string."""
         if self.is_scalar_batch:
-            str_info = f'power = {self.power}'
+            str_info = 'power = {}'.format(self.power)
         else:
-            str_info = f'batch_shape = {self.batch_shape}'
+            str_info = 'batch_shape = {}'.format(self.batch_shape)
         return str_info

-
     def _forward(self, x):
         """
         Evaluate the forward mapping.
@@ -101,10 +101,11 @@ class ScalarAffine(Bijector):
         return self._shift

     def extend_repr(self):
+        """Display instance object as string."""
         if self.is_scalar_batch:
-            str_info = f'scale = {self.scale}, shift = {self.shift}'
+            str_info = 'scale = {}, shift = {}'.format(self.scale, self.shift)
         else:
-            str_info = f'batch_shape = {self.batch_shape}'
+            str_info = 'batch_shape = {}'.format(self.batch_shape)
         return str_info

     def _forward(self, x):
@@ -122,6 +122,7 @@ class Softplus(Bijector):
         ones = self.fill(self.dtypeop(x), self.shape(x), 1.0)
         too_small_or_too_large = self.logicalor(too_small, too_large)
         x = self.select(too_small_or_too_large, ones, x)
+        # pylint: disable=E1130
         y = x + self.log(self.abs(self.expm1(-x)))
         return self.select(too_small, too_small_value, self.select(too_large, too_large_value, y))

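Note: this block evaluates the inverse of softplus, log(e**x - 1), in the overflow-safe form x + log|expm1(-x)|; the surrounding select calls clamp arguments too small or too large for the identity to be computed directly. A NumPy sketch of the equivalence (illustration only):

    import numpy as np

    x = np.array([0.5, 5.0, 20.0])
    naive = np.log(np.expm1(x))                 # overflows once e**x overflows
    stable = x + np.log(np.abs(np.expm1(-x)))   # form used in the patch
    np.testing.assert_allclose(naive, stable, rtol=1e-6)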
@@ -130,10 +131,11 @@ class Softplus(Bijector):
         return self._sharpness

     def extend_repr(self):
+        """Display instance object as string."""
         if self.is_scalar_batch:
-            str_info = f'sharpness = {self.sharpness}'
+            str_info = 'sharpness = {}'.format(self.sharpness)
         else:
-            str_info = f'batch_shape = {self.batch_shape}'
+            str_info = 'batch_shape = {}'.format(self.batch_shape)
         return str_info

     def _forward(self, x):
@@ -72,9 +72,7 @@ class _ConvVariational(_Conv):
         self.group = group
         self.has_bias = has_bias

         # distribution trainable parameters
-        self.shape = [self.out_channels,
-                      self.in_channels // self.group, *self.kernel_size]
-
+        self.shape = [self.out_channels, self.in_channels // self.group, *self.kernel_size]
         self.weight.requires_grad = False
         self.weight_prior = check_prior(weight_prior_fn, "weight_prior_fn")
@@ -108,6 +106,7 @@ class _ConvVariational(_Conv):
         return outputs

     def extend_repr(self):
+        """Display instance object as string."""
         s = 'in_channels={}, out_channels={}, kernel_size={}, stride={}, pad_mode={}, ' \
             'padding={}, dilation={}, group={}, weight_mean={}, weight_std={}, has_bias={}' \
             .format(self.in_channels, self.out_channels, self.kernel_size, self.stride, self.pad_mode, self.padding,
@@ -135,6 +134,7 @@ class _ConvVariational(_Conv):
         return kl_loss

     def apply_variational_bias(self, inputs):
+        """Calculate bias."""
         bias_posterior_tensor = self.bias_posterior("sample")
         return self.bias_add(inputs, bias_posterior_tensor)

@@ -261,6 +261,7 @@ class ConvReparam(_ConvVariational):
         )

     def apply_variational_weight(self, inputs):
+        """Calculate weight."""
         weight_posterior_tensor = self.weight_posterior("sample")
         outputs = self.conv2d(inputs, weight_posterior_tensor)
         return outputs
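Note: in these variational layers the weights are not fixed Parameters; a fresh sample is drawn from the learned posterior on every forward pass. A usage sketch, assuming ConvReparam lives under mindspore.nn.probability.bnn_layers and accepts the constructor arguments shown here:

    import numpy as np
    import mindspore
    from mindspore import Tensor
    from mindspore.nn.probability.bnn_layers import ConvReparam

    conv = ConvReparam(in_channels=3, out_channels=8, kernel_size=3)
    x = Tensor(np.ones((1, 3, 32, 32)), mindspore.float32)
    y = conv(x)  # weights are re-sampled from the posterior on each call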
@@ -78,6 +78,7 @@ class _DenseVariational(Cell):
         return outputs

     def extend_repr(self):
+        """Display instance object as string."""
         s = 'in_channels={}, out_channels={}, weight_mean={}, weight_std={}, has_bias={}' \
             .format(self.in_channels, self.out_channels, self.weight_posterior.mean,
                     self.weight_posterior.untransformed_std, self.has_bias)
@@ -89,6 +90,7 @@ class _DenseVariational(Cell):
         return s

     def apply_variational_bias(self, inputs):
+        """Calculate bias."""
         bias_posterior_tensor = self.bias_posterior("sample")
         return self.bias_add(inputs, bias_posterior_tensor)

@@ -196,6 +198,7 @@ class DenseReparam(_DenseVariational):
         )

     def apply_variational_weight(self, inputs):
+        """Calculate weight."""
         weight_posterior_tensor = self.weight_posterior("sample")
         outputs = self.matmul(inputs, weight_posterior_tensor)
         return outputs
@@ -292,6 +295,7 @@ class DenseLocalReparam(_DenseVariational):
         self.normal = Normal()

     def apply_variational_weight(self, inputs):
+        """Calculate weight."""
         mean = self.matmul(inputs, self.weight_posterior("mean"))
         std = self.sqrt(self.matmul(self.square(inputs), self.square(self.weight_posterior("sd"))))
         weight_posterior_affine_tensor = self.normal("sample", mean=mean, sd=std)
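Note: DenseLocalReparam uses the local reparameterization trick: rather than sampling a weight matrix, it samples the pre-activation directly, whose mean is x @ W_mean and whose variance is x**2 @ W_sd**2, which lowers gradient variance. A NumPy sketch of the idea (names and shapes are illustrative):

    import numpy as np

    rng = np.random.default_rng(0)
    x = rng.normal(size=(4, 16))                # batch of inputs
    w_mean = rng.normal(size=(16, 8))           # posterior mean of the weights
    w_sd = np.abs(rng.normal(size=(16, 8)))     # posterior std of the weights

    mean = x @ w_mean                           # mean of the activation
    std = np.sqrt((x ** 2) @ (w_sd ** 2))       # std of the activation
    activation = rng.normal(mean, std)          # sample activations, not weights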
@@ -27,7 +27,7 @@ class Bernoulli(Distribution):
     Bernoulli Distribution.

     Args:
-        probs (float, list, numpy.ndarray, Tensor): The probability of that the outcome is 1.
+        probs (float, list, numpy.ndarray, Tensor): The probability of that the outcome is 1. Default: None.
         seed (int): The seed used in sampling. The global seed is used if it is None. Default: None.
         dtype (mindspore.dtype): The type of the event samples. Default: mstype.int32.
         name (str): The name of the distribution. Default: 'Bernoulli'.
@@ -153,10 +153,11 @@ class Bernoulli(Distribution):
         self.uniform = C.uniform

     def extend_repr(self):
+        """Display instance object as string."""
         if self.is_scalar_batch:
-            s = f'probs = {self.probs}'
+            s = 'probs = {}'.format(self.probs)
         else:
-            s = f'batch_shape = {self._broadcast_shape}'
+            s = 'batch_shape = {}'.format(self._broadcast_shape)
         return s

     @property
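Note: the scalar-batch branch prints the parameter values, the batched branch only the batch shape. A quick sketch, assuming the usual distribution import path:

    import mindspore.nn.probability.distribution as msd

    b = msd.Bernoulli(0.5)
    print(b)  # summary includes 'probs = 0.5' for a scalar batch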
@@ -181,10 +181,11 @@ class Beta(Distribution):
         self.lbeta = nn.LBeta()

     def extend_repr(self):
+        """Display instance object as string."""
         if self.is_scalar_batch:
-            s = f'concentration1 = {self._concentration1}, concentration0 = {self._concentration0}'
+            s = 'concentration1 = {}, concentration0 = {}'.format(self._concentration1, self._concentration0)
         else:
-            s = f'batch_shape = {self._broadcast_shape}'
+            s = 'batch_shape = {}'.format(self._broadcast_shape)
         return s

     @property
@@ -171,10 +171,11 @@ class Categorical(Distribution):
         return self._probs

     def extend_repr(self):
+        """Display instance object as string."""
         if self.is_scalar_batch:
-            s = f'probs = {self.probs}'
+            s = 'probs = {}'.format(self.probs)
         else:
-            s = f'batch_shape = {self._broadcast_shape}'
+            s = 'batch_shape = {}'.format(self._broadcast_shape)
         return s

     def _get_dist_type(self):
@@ -173,10 +173,11 @@ class Cauchy(Distribution):


     def extend_repr(self):
+        """Display instance object as string."""
         if self.is_scalar_batch:
-            str_info = f'location = {self._loc}, scale = {self._scale}'
+            str_info = 'location = {}, scale = {}'.format(self._loc, self._scale)
         else:
-            str_info = f'batch_shape = {self._broadcast_shape}'
+            str_info = 'batch_shape = {}'.format(self._broadcast_shape)
         return str_info

     @property
@@ -249,6 +250,7 @@ class Cauchy(Distribution):
         value = self.cast(value, self.dtype)
         loc, scale = self._check_param_type(loc, scale)
         z = (value - loc) / scale
+        # pylint: disable=E1130
         log_unnormalized_prob = - self.log1p(self.sq(z))
         log_normalization = self.log(np.pi * scale)
         return log_unnormalized_prob - log_normalization
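Note: the guarded line computes the Cauchy log-density log p(x) = -log(1 + z**2) - log(pi * scale) with z = (x - loc) / scale. A NumPy check of the same formula, independent of MindSpore:

    import numpy as np

    def cauchy_log_prob(value, loc=0.0, scale=1.0):
        z = (value - loc) / scale
        return -np.log1p(z ** 2) - np.log(np.pi * scale)

    # coarse numeric check that the density integrates to ~1
    xs = np.linspace(-500.0, 500.0, 200001)
    print(np.trapz(np.exp(cauchy_log_prob(xs)), xs))  # ~0.999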
@@ -28,7 +28,7 @@ class Exponential(Distribution):
     Example class: Exponential Distribution.

     Args:
-        rate (float, list, numpy.ndarray, Tensor): The inverse scale.
+        rate (float, list, numpy.ndarray, Tensor): The inverse scale. Default: None.
         seed (int): The seed used in sampling. The global seed is used if it is None. Default: None.
         dtype (mindspore.dtype): The type of the event samples. Default: mstype.float32.
         name (str): The name of the distribution. Default: 'Exponential'.
@@ -156,10 +156,11 @@ class Exponential(Distribution):
         self.uniform = C.uniform

     def extend_repr(self):
+        """Display instance object as string."""
         if self.is_scalar_batch:
-            s = f'rate = {self.rate}'
+            s = 'rate = {}'.format(self.rate)
         else:
-            s = f'batch_shape = {self._broadcast_shape}'
+            s = 'batch_shape = {}'.format(self._broadcast_shape)
         return s

     @property
@@ -180,10 +180,11 @@ class Gamma(Distribution):
         self.igamma = nn.IGamma()

     def extend_repr(self):
+        """Display instance object as string."""
         if self.is_scalar_batch:
-            s = f'concentration = {self._concentration}, rate = {self._rate}'
+            s = 'concentration = {}, rate = {}'.format(self._concentration, self._rate)
         else:
-            s = f'batch_shape = {self._broadcast_shape}'
+            s = 'batch_shape = {}'.format(self._broadcast_shape)
         return s

     @property
@@ -165,10 +165,11 @@ class Geometric(Distribution):
         self.uniform = C.uniform

     def extend_repr(self):
+        """Display instance object as string."""
         if not self.is_scalar_batch:
-            s = f'batch_shape = {self._broadcast_shape}'
+            s = 'batch_shape = {}'.format(self._broadcast_shape)
         else:
-            s = f'probs = {self.probs}'
+            s = 'probs = {}'.format(self.probs)
         return s

     @property
@@ -112,10 +112,11 @@ class Gumbel(TransformedDistribution):
         return self._scale

     def extend_repr(self):
+        """Display instance object as string."""
         if self.is_scalar_batch:
-            str_info = f'loc = {self._loc}, scale = {self._scale}'
+            str_info = 'loc = {}, scale = {}'.format(self._loc, self._scale)
         else:
-            str_info = f'batch_shape = {self._broadcast_shape}'
+            str_info = 'batch_shape = {}'.format(self._broadcast_shape)
         return str_info

     def _get_dist_type(self):
@@ -129,10 +129,11 @@ class LogNormal(msd.TransformedDistribution):
         return loc, scale

     def extend_repr(self):
+        """Display instance object as string."""
         if self.is_scalar_batch:
-            s = f'loc = {self.loc}, scale = {self.scale}'
+            s = 'loc = {}, scale = {}'.format(self.loc, self.scale)
         else:
-            s = f'batch_shape = {self.broadcast_shape}'
+            s = 'batch_shape = {}'.format(self.broadcast_shape)
         return s

     def _mean(self, loc=None, scale=None):
@@ -173,10 +173,11 @@ class Logistic(Distribution):
         return self.select(too_small, too_small_value, self.select(too_large, too_large_value, y))

     def extend_repr(self):
+        """Display instance object as string."""
         if self.is_scalar_batch:
-            s = f'location = {self._loc}, scale = {self._scale}'
+            s = 'location = {}, scale = {}'.format(self._loc, self._scale)
         else:
-            s = f'batch_shape = {self._broadcast_shape}'
+            s = 'batch_shape = {}'.format(self._broadcast_shape)
         return s

     @property
@@ -291,6 +292,7 @@ class Logistic(Distribution):
         value = self.cast(value, self.dtype)
         loc, scale = self._check_param_type(loc, scale)
         z = (value - loc) / scale
+        # pylint: disable=E1130
         return -self.softplus(-z)

     def _survival_function(self, value, loc=None, scale=None):
@@ -327,6 +329,7 @@ class Logistic(Distribution):
         value = self.cast(value, self.dtype)
         loc, scale = self._check_param_type(loc, scale)
         z = (value - loc) / scale
+        # pylint: disable=E1130
         return -self.softplus(z)

     def _sample(self, shape=(), loc=None, scale=None):
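Note: the Logistic CDF is sigmoid(z), so log CDF = -softplus(-z) and log survival = -softplus(z); the disabled warning is again pylint missing the overloaded unary minus. A NumPy sketch of the identities:

    import numpy as np

    def softplus(t):
        return np.logaddexp(0.0, t)  # log(1 + e**t), overflow-safe

    z = np.array([-3.0, 0.0, 2.5])
    cdf = 1.0 / (1.0 + np.exp(-z))   # sigmoid
    np.testing.assert_allclose(np.log(cdf), -softplus(-z), rtol=1e-10)
    np.testing.assert_allclose(np.log1p(-cdf), -softplus(z), rtol=1e-10)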
@@ -164,10 +164,11 @@ class Normal(Distribution):
         self.sqrt = P.Sqrt()

     def extend_repr(self):
+        """Display instance object as string."""
         if self.is_scalar_batch:
-            s = f'mean = {self._mean_value}, standard deviation = {self._sd_value}'
+            s = 'mean = {}, standard deviation = {}'.format(self._mean_value, self._sd_value)
         else:
-            s = f'batch_shape = {self._broadcast_shape}'
+            s = 'batch_shape = {}'.format(self._broadcast_shape)
         return s

     def _get_dist_type(self):
@@ -155,10 +155,11 @@ class Poisson(Distribution):
         return self._rate

     def extend_repr(self):
+        """Display instance object as string."""
         if self.is_scalar_batch:
-            s = f'rate = {self.rate}'
+            s = 'rate = {}'.format(self.rate)
         else:
-            s = f'batch_shape = {self._broadcast_shape}'
+            s = 'batch_shape = {}'.format(self._broadcast_shape)
         return s

     def _get_dist_type(self):
@@ -219,6 +220,7 @@ class Poisson(Distribution):
         safe_x = self.select(self.less(value, zeros), zeros, value)
         y = log_rate * safe_x - self.lgamma(safe_x + 1.)
         comp = self.equal(value, safe_x)
+        # pylint: disable=E1130
         log_unnormalized_prob = self.select(comp, y, -inf)
         log_normalization = self.exp(log_rate)
         return log_unnormalized_prob - log_normalization
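Note: log_normalization = exp(log_rate) is just the rate itself, so the return value is the Poisson log-pmf x * log(rate) - lgamma(x + 1) - rate. A pure-Python sketch of the formula:

    import math

    def poisson_log_prob(x, rate):
        # log pmf = x * log(rate) - log(x!) - rate
        return x * math.log(rate) - math.lgamma(x + 1) - rate

    print(math.exp(poisson_log_prob(2, 3.0)))  # P(X=2 | rate=3) = 9 / (2 * e**3), about 0.224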
@@ -170,10 +170,11 @@ class Uniform(Distribution):
         self.uniform = C.uniform

     def extend_repr(self):
+        """Display instance object as string."""
         if self.is_scalar_batch:
-            s = f'low = {self.low}, high = {self.high}'
+            s = 'low = {}, high = {}'.format(self.low, self.high)
         else:
-            s = f'batch_shape = {self._broadcast_shape}'
+            s = 'batch_shape = {}'.format(self._broadcast_shape)
         return s

     @property