!22040 optimize document of probability

Merge pull request !22040 from byweng/master
This commit is contained in:
i-robot 2021-08-20 07:09:39 +00:00 committed by Gitee
commit 9db8f3221f
22 changed files with 89 additions and 44 deletions

View File

@ -147,6 +147,7 @@ class Bijector(Cell):
return (shape_tensor + dist_shape_tensor).shape return (shape_tensor + dist_shape_tensor).shape
def shape_mapping(self, shape): def shape_mapping(self, shape):
"""Map shape."""
return self._shape_mapping(shape) return self._shape_mapping(shape)
def _add_parameter(self, value, name): def _add_parameter(self, value, name):
@ -161,7 +162,7 @@ class Bijector(Cell):
self.common_dtype = None self.common_dtype = None
# cast value to a tensor if it is not None # cast value to a tensor if it is not None
if isinstance(value, bool) or value is None: if isinstance(value, bool) or value is None:
raise TypeError(f"{name} cannot be type {type(value)}") raise TypeError("{} cannot be type {}".format(name, type(value)))
value_t = Tensor(value) value_t = Tensor(value)
# if the bijector's dtype is not specified # if the bijector's dtype is not specified
if self.dtype is None: if self.dtype is None:

View File

@ -57,8 +57,9 @@ class Exp(PowerTransform):
super(Exp, self).__init__(name=name) super(Exp, self).__init__(name=name)
def extend_repr(self): def extend_repr(self):
"""Display instance object as string."""
if self.is_scalar_batch: if self.is_scalar_batch:
str_info = 'exp' str_info = 'exp'
else: else:
str_info = f'batch_shape = {self.batch_shape}' str_info = 'batch_shape = {}'.format(self.batch_shape)
return str_info return str_info

View File

@ -28,7 +28,7 @@ class GumbelCDF(Bijector):
Y = \exp(-\exp(\frac{-(X - loc)}{scale})) Y = \exp(-\exp(\frac{-(X - loc)}{scale}))
Args: Args:
loc (float, list, numpy.ndarray, Tensor): The location. Default: 0.. loc (float, list, numpy.ndarray, Tensor): The location. Default: 0.0.
scale (float, list, numpy.ndarray, Tensor): The scale. Default: 1.0. scale (float, list, numpy.ndarray, Tensor): The scale. Default: 1.0.
name (str): The name of the Bijector. Default: 'GumbelCDF'. name (str): The name of the Bijector. Default: 'GumbelCDF'.
@ -101,10 +101,11 @@ class GumbelCDF(Bijector):
return self._scale return self._scale
def extend_repr(self): def extend_repr(self):
"""Display instance object as string."""
if self.is_scalar_batch: if self.is_scalar_batch:
str_info = f'loc = {self.loc}, scale = {self.scale}' str_info = 'loc = {}, scale = {}'.format(self.loc, self.scale)
else: else:
str_info = f'batch_shape = {self.batch_shape}' str_info = 'batch_shape = {}'.format(self.batch_shape)
return str_info return str_info
def _forward(self, x): def _forward(self, x):
@ -112,9 +113,12 @@ class GumbelCDF(Bijector):
loc_local = self.cast_param_by_value(x, self.loc) loc_local = self.cast_param_by_value(x, self.loc)
scale_local = self.cast_param_by_value(x, self.scale) scale_local = self.cast_param_by_value(x, self.scale)
z = (x - loc_local) / scale_local z = (x - loc_local) / scale_local
# pylint: disable=E1130
return self.exp(-self.exp(-z)) return self.exp(-self.exp(-z))
def _inverse(self, y): def _inverse(self, y):
# pylint false positive
# pylint: disable=E1130
y = self._check_value_dtype(y) y = self._check_value_dtype(y)
loc_local = self.cast_param_by_value(y, self.loc) loc_local = self.cast_param_by_value(y, self.loc)
scale_local = self.cast_param_by_value(y, self.scale) scale_local = self.cast_param_by_value(y, self.scale)

View File

@ -23,7 +23,8 @@ class Invert(Bijector):
Args: Args:
bijector (Bijector): Base Bijector. bijector (Bijector): Base Bijector.
name (str): The name of the Bijector. Default: 'Invert' + bijector.name. name (str): The name of the Bijector. Default: "". When name is set to "", it is actually
'Invert' + bijector.name.
Supported Platforms: Supported Platforms:
``Ascend`` ``GPU`` ``Ascend`` ``GPU``
@ -67,16 +68,29 @@ class Invert(Bijector):
@property @property
def bijector(self): def bijector(self):
"""Return base bijector."""
return self._bijector return self._bijector
def inverse(self, y): def inverse(self, y):
"""
Forward transformation: transform the input value to another distribution.
"""
return self.bijector("forward", y) return self.bijector("forward", y)
def forward(self, x): def forward(self, x):
"""
Inverse transformation: transform the input value back to the original distribution.
"""
return self.bijector("inverse", x) return self.bijector("inverse", x)
def inverse_log_jacobian(self, y): def inverse_log_jacobian(self, y):
"""
Logarithm of the derivative of the forward transformation.
"""
return self.bijector("forward_log_jacobian", y) return self.bijector("forward_log_jacobian", y)
def forward_log_jacobian(self, x): def forward_log_jacobian(self, x):
"""
Logarithm of the derivative of the inverse transformation.
"""
return self.bijector("inverse_log_jacobian", x) return self.bijector("inverse_log_jacobian", x)

View File

@ -95,13 +95,13 @@ class PowerTransform(Bijector):
return self._power return self._power
def extend_repr(self): def extend_repr(self):
"""Display instance object as string."""
if self.is_scalar_batch: if self.is_scalar_batch:
str_info = f'power = {self.power}' str_info = 'power = {}'.format(self.power)
else: else:
str_info = f'batch_shape = {self.batch_shape}' str_info = 'batch_shape = {}'.format(self.batch_shape)
return str_info return str_info
def _forward(self, x): def _forward(self, x):
""" """
Evaluate the forward mapping. Evaluate the forward mapping.

View File

@ -101,10 +101,11 @@ class ScalarAffine(Bijector):
return self._shift return self._shift
def extend_repr(self): def extend_repr(self):
"""Display instance object as string."""
if self.is_scalar_batch: if self.is_scalar_batch:
str_info = f'scale = {self.scale}, shift = {self.shift}' str_info = 'scale = {}, shift = {}'.format(self.scale, self.shift)
else: else:
str_info = f'batch_shape = {self.batch_shape}' str_info = 'batch_shape = {}'.format(self.batch_shape)
return str_info return str_info
def _forward(self, x): def _forward(self, x):

View File

@ -122,6 +122,7 @@ class Softplus(Bijector):
ones = self.fill(self.dtypeop(x), self.shape(x), 1.0) ones = self.fill(self.dtypeop(x), self.shape(x), 1.0)
too_small_or_too_large = self.logicalor(too_small, too_large) too_small_or_too_large = self.logicalor(too_small, too_large)
x = self.select(too_small_or_too_large, ones, x) x = self.select(too_small_or_too_large, ones, x)
# pylint: disable=E1130
y = x + self.log(self.abs(self.expm1(-x))) y = x + self.log(self.abs(self.expm1(-x)))
return self.select(too_small, too_small_value, self.select(too_large, too_large_value, y)) return self.select(too_small, too_small_value, self.select(too_large, too_large_value, y))
@ -130,10 +131,11 @@ class Softplus(Bijector):
return self._sharpness return self._sharpness
def extend_repr(self): def extend_repr(self):
"""Display instance object as string."""
if self.is_scalar_batch: if self.is_scalar_batch:
str_info = f'sharpness = {self.sharpness}' str_info = 'sharpness = {}'.format(self.sharpness)
else: else:
str_info = f'batch_shape = {self.batch_shape}' str_info = 'batch_shape = {}'.format(self.batch_shape)
return str_info return str_info
def _forward(self, x): def _forward(self, x):

View File

@ -72,9 +72,7 @@ class _ConvVariational(_Conv):
self.group = group self.group = group
self.has_bias = has_bias self.has_bias = has_bias
# distribution trainable parameters self.shape = [self.out_channels, self.in_channels // self.group, *self.kernel_size]
self.shape = [self.out_channels,
self.in_channels // self.group, *self.kernel_size]
self.weight.requires_grad = False self.weight.requires_grad = False
self.weight_prior = check_prior(weight_prior_fn, "weight_prior_fn") self.weight_prior = check_prior(weight_prior_fn, "weight_prior_fn")
@ -108,6 +106,7 @@ class _ConvVariational(_Conv):
return outputs return outputs
def extend_repr(self): def extend_repr(self):
"""Display instance object as string."""
s = 'in_channels={}, out_channels={}, kernel_size={}, stride={}, pad_mode={}, ' \ s = 'in_channels={}, out_channels={}, kernel_size={}, stride={}, pad_mode={}, ' \
'padding={}, dilation={}, group={}, weight_mean={}, weight_std={}, has_bias={}' \ 'padding={}, dilation={}, group={}, weight_mean={}, weight_std={}, has_bias={}' \
.format(self.in_channels, self.out_channels, self.kernel_size, self.stride, self.pad_mode, self.padding, .format(self.in_channels, self.out_channels, self.kernel_size, self.stride, self.pad_mode, self.padding,
@ -135,6 +134,7 @@ class _ConvVariational(_Conv):
return kl_loss return kl_loss
def apply_variational_bias(self, inputs): def apply_variational_bias(self, inputs):
"""Calculate bias."""
bias_posterior_tensor = self.bias_posterior("sample") bias_posterior_tensor = self.bias_posterior("sample")
return self.bias_add(inputs, bias_posterior_tensor) return self.bias_add(inputs, bias_posterior_tensor)
@ -261,6 +261,7 @@ class ConvReparam(_ConvVariational):
) )
def apply_variational_weight(self, inputs): def apply_variational_weight(self, inputs):
"""Calculate weight."""
weight_posterior_tensor = self.weight_posterior("sample") weight_posterior_tensor = self.weight_posterior("sample")
outputs = self.conv2d(inputs, weight_posterior_tensor) outputs = self.conv2d(inputs, weight_posterior_tensor)
return outputs return outputs

View File

@ -78,6 +78,7 @@ class _DenseVariational(Cell):
return outputs return outputs
def extend_repr(self): def extend_repr(self):
"""Display instance object as string."""
s = 'in_channels={}, out_channels={}, weight_mean={}, weight_std={}, has_bias={}' \ s = 'in_channels={}, out_channels={}, weight_mean={}, weight_std={}, has_bias={}' \
.format(self.in_channels, self.out_channels, self.weight_posterior.mean, .format(self.in_channels, self.out_channels, self.weight_posterior.mean,
self.weight_posterior.untransformed_std, self.has_bias) self.weight_posterior.untransformed_std, self.has_bias)
@ -89,6 +90,7 @@ class _DenseVariational(Cell):
return s return s
def apply_variational_bias(self, inputs): def apply_variational_bias(self, inputs):
"""Calculate bias."""
bias_posterior_tensor = self.bias_posterior("sample") bias_posterior_tensor = self.bias_posterior("sample")
return self.bias_add(inputs, bias_posterior_tensor) return self.bias_add(inputs, bias_posterior_tensor)
@ -196,6 +198,7 @@ class DenseReparam(_DenseVariational):
) )
def apply_variational_weight(self, inputs): def apply_variational_weight(self, inputs):
"""Calculate weight."""
weight_posterior_tensor = self.weight_posterior("sample") weight_posterior_tensor = self.weight_posterior("sample")
outputs = self.matmul(inputs, weight_posterior_tensor) outputs = self.matmul(inputs, weight_posterior_tensor)
return outputs return outputs
@ -292,6 +295,7 @@ class DenseLocalReparam(_DenseVariational):
self.normal = Normal() self.normal = Normal()
def apply_variational_weight(self, inputs): def apply_variational_weight(self, inputs):
"""Calculate weight."""
mean = self.matmul(inputs, self.weight_posterior("mean")) mean = self.matmul(inputs, self.weight_posterior("mean"))
std = self.sqrt(self.matmul(self.square(inputs), self.square(self.weight_posterior("sd")))) std = self.sqrt(self.matmul(self.square(inputs), self.square(self.weight_posterior("sd"))))
weight_posterior_affine_tensor = self.normal("sample", mean=mean, sd=std) weight_posterior_affine_tensor = self.normal("sample", mean=mean, sd=std)

View File

@ -27,7 +27,7 @@ class Bernoulli(Distribution):
Bernoulli Distribution. Bernoulli Distribution.
Args: Args:
probs (float, list, numpy.ndarray, Tensor): The probability of that the outcome is 1. probs (float, list, numpy.ndarray, Tensor): The probability of that the outcome is 1. Default: None.
seed (int): The seed used in sampling. The global seed is used if it is None. Default: None. seed (int): The seed used in sampling. The global seed is used if it is None. Default: None.
dtype (mindspore.dtype): The type of the event samples. Default: mstype.int32. dtype (mindspore.dtype): The type of the event samples. Default: mstype.int32.
name (str): The name of the distribution. Default: 'Bernoulli'. name (str): The name of the distribution. Default: 'Bernoulli'.
@ -153,10 +153,11 @@ class Bernoulli(Distribution):
self.uniform = C.uniform self.uniform = C.uniform
def extend_repr(self): def extend_repr(self):
"""Display instance object as string."""
if self.is_scalar_batch: if self.is_scalar_batch:
s = f'probs = {self.probs}' s = 'probs = {}'.format(self.probs)
else: else:
s = f'batch_shape = {self._broadcast_shape}' s = 'batch_shape = {}'.format(self._broadcast_shape)
return s return s
@property @property

View File

@ -181,10 +181,11 @@ class Beta(Distribution):
self.lbeta = nn.LBeta() self.lbeta = nn.LBeta()
def extend_repr(self): def extend_repr(self):
"""Display instance object as string."""
if self.is_scalar_batch: if self.is_scalar_batch:
s = f'concentration1 = {self._concentration1}, concentration0 = {self._concentration0}' s = 'concentration1 = {}, concentration0 = {}'.format(self._concentration1, self._concentration0)
else: else:
s = f'batch_shape = {self._broadcast_shape}' s = 'batch_shape = {}'.format(self._broadcast_shape)
return s return s
@property @property

View File

@ -171,10 +171,11 @@ class Categorical(Distribution):
return self._probs return self._probs
def extend_repr(self): def extend_repr(self):
"""Display instance object as string."""
if self.is_scalar_batch: if self.is_scalar_batch:
s = f'probs = {self.probs}' s = 'probs = {}'.format(self.probs)
else: else:
s = f'batch_shape = {self._broadcast_shape}' s = 'batch_shape = {}'.format(self._broadcast_shape)
return s return s
def _get_dist_type(self): def _get_dist_type(self):

View File

@ -173,10 +173,11 @@ class Cauchy(Distribution):
def extend_repr(self): def extend_repr(self):
"""Display instance object as string."""
if self.is_scalar_batch: if self.is_scalar_batch:
str_info = f'location = {self._loc}, scale = {self._scale}' str_info = 'location = {}, scale = {}'.format(self._loc, self._scale)
else: else:
str_info = f'batch_shape = {self._broadcast_shape}' str_info = 'batch_shape = {}'.format(self._broadcast_shape)
return str_info return str_info
@property @property
@ -249,6 +250,7 @@ class Cauchy(Distribution):
value = self.cast(value, self.dtype) value = self.cast(value, self.dtype)
loc, scale = self._check_param_type(loc, scale) loc, scale = self._check_param_type(loc, scale)
z = (value - loc) / scale z = (value - loc) / scale
# pylint: disable=E1130
log_unnormalized_prob = - self.log1p(self.sq(z)) log_unnormalized_prob = - self.log1p(self.sq(z))
log_normalization = self.log(np.pi * scale) log_normalization = self.log(np.pi * scale)
return log_unnormalized_prob - log_normalization return log_unnormalized_prob - log_normalization

View File

@ -28,7 +28,7 @@ class Exponential(Distribution):
Example class: Exponential Distribution. Example class: Exponential Distribution.
Args: Args:
rate (float, list, numpy.ndarray, Tensor): The inverse scale. rate (float, list, numpy.ndarray, Tensor): The inverse scale. Default: None.
seed (int): The seed used in sampling. The global seed is used if it is None. Default: None. seed (int): The seed used in sampling. The global seed is used if it is None. Default: None.
dtype (mindspore.dtype): The type of the event samples. Default: mstype.float32. dtype (mindspore.dtype): The type of the event samples. Default: mstype.float32.
name (str): The name of the distribution. Default: 'Exponential'. name (str): The name of the distribution. Default: 'Exponential'.
@ -156,10 +156,11 @@ class Exponential(Distribution):
self.uniform = C.uniform self.uniform = C.uniform
def extend_repr(self): def extend_repr(self):
"""Display instance object as string."""
if self.is_scalar_batch: if self.is_scalar_batch:
s = f'rate = {self.rate}' s = 'rate = {}'.format(self.rate)
else: else:
s = f'batch_shape = {self._broadcast_shape}' s = 'batch_shape = {}'.format(self._broadcast_shape)
return s return s
@property @property

View File

@ -180,10 +180,11 @@ class Gamma(Distribution):
self.igamma = nn.IGamma() self.igamma = nn.IGamma()
def extend_repr(self): def extend_repr(self):
"""Display instance object as string."""
if self.is_scalar_batch: if self.is_scalar_batch:
s = f'concentration = {self._concentration}, rate = {self._rate}' s = 'concentration = {}, rate = {}'.format(self._concentration, self._rate)
else: else:
s = f'batch_shape = {self._broadcast_shape}' s = 'batch_shape = {}'.format(self._broadcast_shape)
return s return s
@property @property

View File

@ -165,10 +165,11 @@ class Geometric(Distribution):
self.uniform = C.uniform self.uniform = C.uniform
def extend_repr(self): def extend_repr(self):
"""Display instance object as string."""
if not self.is_scalar_batch: if not self.is_scalar_batch:
s = f'batch_shape = {self._broadcast_shape}' s = 'batch_shape = {}'.format(self._broadcast_shape)
else: else:
s = f'probs = {self.probs}' s = 'probs = {}'.format(self.probs)
return s return s
@property @property

View File

@ -112,10 +112,11 @@ class Gumbel(TransformedDistribution):
return self._scale return self._scale
def extend_repr(self): def extend_repr(self):
"""Display instance object as string."""
if self.is_scalar_batch: if self.is_scalar_batch:
str_info = f'loc = {self._loc}, scale = {self._scale}' str_info = 'loc = {}, scale = {}'.format(self._loc, self._scale)
else: else:
str_info = f'batch_shape = {self._broadcast_shape}' str_info = 'batch_shape = {}'.format(self._broadcast_shape)
return str_info return str_info
def _get_dist_type(self): def _get_dist_type(self):

View File

@ -129,10 +129,11 @@ class LogNormal(msd.TransformedDistribution):
return loc, scale return loc, scale
def extend_repr(self): def extend_repr(self):
"""Display instance object as string."""
if self.is_scalar_batch: if self.is_scalar_batch:
s = f'loc = {self.loc}, scale = {self.scale}' s = 'loc = {}, scale = {}'.format(self.loc, self.scale)
else: else:
s = f'batch_shape = {self.broadcast_shape}' s = 'batch_shape = {}'.format(self.broadcast_shape)
return s return s
def _mean(self, loc=None, scale=None): def _mean(self, loc=None, scale=None):

View File

@ -173,10 +173,11 @@ class Logistic(Distribution):
return self.select(too_small, too_small_value, self.select(too_large, too_large_value, y)) return self.select(too_small, too_small_value, self.select(too_large, too_large_value, y))
def extend_repr(self): def extend_repr(self):
"""Display instance object as string."""
if self.is_scalar_batch: if self.is_scalar_batch:
s = f'location = {self._loc}, scale = {self._scale}' s = 'location = {}, scale = {}'.format(self._loc, self._scale)
else: else:
s = f'batch_shape = {self._broadcast_shape}' s = 'batch_shape = {}'.format(self._broadcast_shape)
return s return s
@property @property
@ -291,6 +292,7 @@ class Logistic(Distribution):
value = self.cast(value, self.dtype) value = self.cast(value, self.dtype)
loc, scale = self._check_param_type(loc, scale) loc, scale = self._check_param_type(loc, scale)
z = (value - loc) / scale z = (value - loc) / scale
# pylint: disable=E1130
return -self.softplus(-z) return -self.softplus(-z)
def _survival_function(self, value, loc=None, scale=None): def _survival_function(self, value, loc=None, scale=None):
@ -327,6 +329,7 @@ class Logistic(Distribution):
value = self.cast(value, self.dtype) value = self.cast(value, self.dtype)
loc, scale = self._check_param_type(loc, scale) loc, scale = self._check_param_type(loc, scale)
z = (value - loc) / scale z = (value - loc) / scale
# pylint: disable=E1130
return -self.softplus(z) return -self.softplus(z)
def _sample(self, shape=(), loc=None, scale=None): def _sample(self, shape=(), loc=None, scale=None):

View File

@ -164,10 +164,11 @@ class Normal(Distribution):
self.sqrt = P.Sqrt() self.sqrt = P.Sqrt()
def extend_repr(self): def extend_repr(self):
"""Display instance object as string."""
if self.is_scalar_batch: if self.is_scalar_batch:
s = f'mean = {self._mean_value}, standard deviation = {self._sd_value}' s = 'mean = {}, standard deviation = {}'.format(self._mean_value, self._sd_value)
else: else:
s = f'batch_shape = {self._broadcast_shape}' s = 'batch_shape = {}'.format(self._broadcast_shape)
return s return s
def _get_dist_type(self): def _get_dist_type(self):

View File

@ -155,10 +155,11 @@ class Poisson(Distribution):
return self._rate return self._rate
def extend_repr(self): def extend_repr(self):
"""Display instance object as string."""
if self.is_scalar_batch: if self.is_scalar_batch:
s = f'rate = {self.rate}' s = 'rate = {}'.format(self.rate)
else: else:
s = f'batch_shape = {self._broadcast_shape}' s = 'batch_shape = {}'.format(self._broadcast_shape)
return s return s
def _get_dist_type(self): def _get_dist_type(self):
@ -219,6 +220,7 @@ class Poisson(Distribution):
safe_x = self.select(self.less(value, zeros), zeros, value) safe_x = self.select(self.less(value, zeros), zeros, value)
y = log_rate * safe_x - self.lgamma(safe_x + 1.) y = log_rate * safe_x - self.lgamma(safe_x + 1.)
comp = self.equal(value, safe_x) comp = self.equal(value, safe_x)
# pylint: disable=E1130
log_unnormalized_prob = self.select(comp, y, -inf) log_unnormalized_prob = self.select(comp, y, -inf)
log_normalization = self.exp(log_rate) log_normalization = self.exp(log_rate)
return log_unnormalized_prob - log_normalization return log_unnormalized_prob - log_normalization

View File

@ -170,10 +170,11 @@ class Uniform(Distribution):
self.uniform = C.uniform self.uniform = C.uniform
def extend_repr(self): def extend_repr(self):
"""Display instance object as string."""
if self.is_scalar_batch: if self.is_scalar_batch:
s = f'low = {self.low}, high = {self.high}' s = 'low = {}, high = {}'.format(self.low, self.high)
else: else:
s = f'batch_shape = {self._broadcast_shape}' s = 'batch_shape = {}'.format(self._broadcast_shape)
return s return s
@property @property