redesigned bijector classes, changed bacth_shape calculation in transformed distribution and upgraded dtype logic of bijector class

This commit is contained in:
Xun Deng 2020-10-22 14:01:44 -04:00
parent 05643cb20c
commit fb0263f869
13 changed files with 451 additions and 158 deletions

View File

@ -16,8 +16,10 @@
from mindspore import context
from mindspore.nn.cell import Cell
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore._checkparam import Validator as validator
from ..distribution._utils.utils import CheckTensor, cast_to_tensor
from ..distribution._utils.utils import CheckTensor, cast_to_tensor, raise_type_error
from ..distribution import Distribution
from ..distribution import TransformedDistribution
@ -32,6 +34,17 @@ class Bijector(Cell):
name (str): The name of the Bijector. Default: None.
dtype (mindspore.dtype): The type of the distributions that the Bijector can operate on. Default: None.
param (dict): The parameters used to initialize the Bijector. Default: None.
Note:
`dtype` of bijector represents the type of the distributions that the bijector could operate on.
When `dtype` is None, there is no enforcement on the type of input value except that the input value
has to be float type. During initilization, when `dtype` is None, there is no enforcement on the dtype
of the parameters. All parameters should have the same float type, otherwise a TypeError will be raised.
Specifically, the parameter type will follow the dtype of the input value, i.e. parameters of the bijector
will be casted into the same type as input value when `dtype`is None.
When `dtype` is specified, it is forcing the parameters and input value to be the same dtype as `dtype`.
When the type of parameters or the type of the input value is not the same as `dtype`, a TypeError will be
raised. Only subtype of mindspore.float_type can be used to specify bijector's `dtype`.
"""
def __init__(self,
@ -48,6 +61,8 @@ class Bijector(Cell):
validator.check_value_type(
'is_constant_jacobian', is_constant_jacobian, [bool], name)
validator.check_value_type('is_injective', is_injective, [bool], name)
if dtype is not None:
validator.check_type_name("dtype", dtype, mstype.float_type, type(self).__name__)
self._name = name
self._dtype = dtype
self._parameters = {}
@ -57,6 +72,12 @@ class Bijector(Cell):
continue
if not(k == 'self' or k.startswith('_')):
self._parameters[k] = param[k]
# if no bijector is used as an argument during initilization
if 'bijector' not in param.keys():
self._batch_shape = self._calc_batch_shape()
self._is_scalar_batch = self._check_is_scalar_batch()
self._is_constant_jacobian = is_constant_jacobian
self._is_injective = is_injective
@ -68,6 +89,8 @@ class Bijector(Cell):
self.dtype_base = P.DType()
self.shape_base = P.Shape()
self.fill_base = P.Fill()
self.sametypeshape_base = P.SameTypeShape()
self.issubclass_base = P.IsSubClass()
@property
def name(self):
@ -89,6 +112,38 @@ class Bijector(Cell):
def is_injective(self):
return self._is_injective
@property
def batch_shape(self):
return self._batch_shape
@property
def is_scalar_batch(self):
return self._is_scalar_batch
def _check_value_dtype(self, value):
"""
Firstly check if the input value is Tensor. Then, if `self.dtype` is None, check
if the input tensor is or can be directly cast into a float tensor.
If `self.dtype` is not None, check if the input tensor's dtype is `self.dtype`.
"""
self.checktensor(value, 'input value of bijector')
value_type = self.dtype_base(value)
if self.dtype is None:
if self.issubclass_base(value_type, mstype.float_):
return value
return raise_type_error('input value of bijector', value_type, mstype.float_)
dtype_tensor = self.fill_base(self.dtype, self.shape_base(value), 0.0)
self.sametypeshape_base(value, dtype_tensor)
return value
def _shape_mapping(self, shape):
shape_tensor = self.fill_base(self.parameter_type, shape, 0.0)
dist_shape_tensor = self.fill_base(self.parameter_type, self.batch_shape, 0.0)
return (shape_tensor + dist_shape_tensor).shape
def shape_mapping(self, shape):
return self._shape_mapping(shape)
def _add_parameter(self, value, name):
"""
Cast `value` to a tensor and add it to `self.default_parameters`.
@ -98,26 +153,51 @@ class Bijector(Cell):
if not hasattr(self, 'default_parameters'):
self.default_parameters = []
self.parameter_names = []
self.common_dtype = None
# cast value to a tensor if it is not None
value_t = None if value is None else cast_to_tensor(value, self.parameter_type)
self.default_parameters += [value_t,]
if isinstance(value, bool) or value is None:
raise TypeError(f"{name} cannot be type {type(value)}")
value_t = Tensor(value)
# if the bijector's dtype is not specified
if self.dtype is None:
if self.common_dtype is None:
self.common_dtype = value_t.dtype
elif value_t.dtype != self.common_dtype:
raise TypeError(f"{name} should have the same dtype as other arguments.")
# check if the dtype of the input_parameter agrees with the bijector's dtype
elif value_t.dtype != self.dtype:
raise TypeError(f"{name} should have the same dtype as the bijector's dtype.")
self.default_parameters += [value,]
self.parameter_names += [name,]
return value_t
def _calc_event_shape(self):
def _calc_batch_shape(self):
"""
Calculate event_shape based on parameters.
Calculate batch_shape based on parameters.
"""
broadcast_shape = None
for param in self.default_parameters:
if broadcast_shape is None:
broadcast_shape = self.shape_base(param)
broadcast_shape_tensor = self.fill_base(self.parameter_type, broadcast_shape, 0.0)
param_dict = self.parameters['param_dict']
broadcast_shape_tensor = None
for value in param_dict.values():
if value is None:
return None
if broadcast_shape_tensor is None:
broadcast_shape_tensor = cast_to_tensor(value)
else:
broadcast_shape = self.shape_base(param + broadcast_shape_tensor)
broadcast_shape_tensor = self.fill_base(self.parameter_type, broadcast_shape, 0.0)
return broadcast_shape
value = cast_to_tensor(value)
broadcast_shape_tensor = (value + broadcast_shape_tensor)
return broadcast_shape_tensor.shape
def _check_is_scalar_batch(self):
"""
Check if the parameters used during initialization are scalars.
"""
param_dict = self.parameters['param_dict']
for value in param_dict.values():
if value is None:
continue
if not isinstance(value, (int, float)):
return False
return True
def _check_value(self, value, name):
"""
@ -127,32 +207,35 @@ class Bijector(Cell):
return value
def cast_param_by_value(self, value, para):
"""
Cast the parameter(s) of the bijector to be the same type of input_value.
"""
local = self.cast_base(para, self.dtype_base(value))
return local
def forward(self, *args, **kwargs):
def forward(self, value, *args, **kwargs):
"""
Forward transformation: transform the input value to another distribution.
"""
return self._forward(*args, **kwargs)
return self._forward(value, *args, **kwargs)
def inverse(self, *args, **kwargs):
def inverse(self, value, *args, **kwargs):
"""
Inverse transformation: transform the input value back to the original distribution.
"""
return self._inverse(*args, **kwargs)
return self._inverse(value, *args, **kwargs)
def forward_log_jacobian(self, *args, **kwargs):
def forward_log_jacobian(self, value, *args, **kwargs):
"""
Logarithm of the derivative of the forward transformation.
"""
return self._forward_log_jacobian(*args, **kwargs)
return self._forward_log_jacobian(value, *args, **kwargs)
def inverse_log_jacobian(self, *args, **kwargs):
def inverse_log_jacobian(self, value, *args, **kwargs):
"""
Logarithm of the derivative of the inverse transformation.
"""
return self._inverse_log_jacobian(*args, **kwargs)
return self._inverse_log_jacobian(value, *args, **kwargs)
def __call__(self, *args, **kwargs):
"""
@ -167,7 +250,7 @@ class Bijector(Cell):
*args: args[0] shall be either a distribution or the name of a Bijector function.
"""
if isinstance(args[0], Distribution):
return TransformedDistribution(self, args[0], self.distribution.dtype)
return TransformedDistribution(self, args[0])
return super(Bijector, self).__call__(*args, **kwargs)
def construct(self, name, *args, **kwargs):

View File

@ -13,10 +13,8 @@
# limitations under the License.
# ============================================================================
"""GumbelCDF Bijector"""
from mindspore.common import dtype as mstype
from mindspore._checkparam import Validator
from mindspore.ops import operations as P
from ..distribution._utils.utils import check_greater_zero, set_param_type
from ..distribution._utils.utils import check_greater_zero
from ..distribution._utils.custom_ops import exp_generic, log_generic
from .bijector import Bijector
@ -30,12 +28,11 @@ class GumbelCDF(Bijector):
Y = \exp(-\exp(\frac{-(X - loc)}{scale}))
Note:
For `reverse` and `reverse_log_jacobian`, input should be in range of (0, 1).
For `inverse` and `inverse_log_jacobian`, input should be in range of (0, 1).
Args:
loc (int, float, list, numpy.ndarray, Tensor): The location. Default: 0..
scale (int, float, list, numpy.ndarray, Tensor): The scale. Default: 1.0.
dtype (mindspore.dtype): Type of the distribution which the bijector operates on. Default: float32.
loc (float, list, numpy.ndarray, Tensor): The location. Default: 0..
scale (float, list, numpy.ndarray, Tensor): The scale. Default: 1.0.
name (str): The name of the Bijector. Default: 'Gumbel_CDF'.
Examples:
@ -61,22 +58,18 @@ class GumbelCDF(Bijector):
def __init__(self,
loc=0.0,
scale=1.0,
dtype=mstype.float32,
name='GumbelCDF'):
"""
Constructor of GumbelCDF Bijector.
"""
param = dict(locals())
valid_dtype = mstype.float_type + mstype.int_type + mstype.uint_type
Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__)
parameter_type = set_param_type({'loc': loc, "scale": scale}, dtype)
super(GumbelCDF, self).__init__(name=name, dtype=dtype, param=param)
param['param_dict'] = {'loc': loc, 'scale': scale}
super(GumbelCDF, self).__init__(name=name, param=param)
self._parameter_type = parameter_type
self._loc = self._add_parameter(loc, 'loc')
self._scale = self._add_parameter(scale, 'scale')
check_greater_zero(self._scale, "scale")
self._event_shape = self._calc_event_shape()
self.cast = P.Cast()
self.exp = exp_generic
@ -91,38 +84,34 @@ class GumbelCDF(Bijector):
def scale(self):
return self._scale
@property
def event_shape(self):
return self._event_shape
@property
def parameter_type(self):
return self._parameter_type
def extend_repr(self):
return f'loc = {self.loc}, scale = {self.scale}'
def shape_mapping(self, shape):
return shape
if self.is_scalar_batch:
str_info = f'loc = {self.loc}, scale = {self.scale}'
else:
str_info = f'batch_shape = {self.batch_shape}'
return str_info
def _forward(self, x):
x = self._check_value(x, 'value')
x = self.cast(x, self.parameter_type)
z = (x - self.loc) / self.scale
x = self._check_value_dtype(x)
loc_local = self.cast_param_by_value(x, self.loc)
scale_local = self.cast_param_by_value(x, self.scale)
z = (x - loc_local) / scale_local
return self.exp(-self.exp(-z))
def _inverse(self, y):
y = self._check_value(y, 'value')
y = self.cast(y, self.parameter_type)
return self.loc - self.scale * self.log(-self.log(y))
y = self._check_value_dtype(y)
loc_local = self.cast_param_by_value(y, self.loc)
scale_local = self.cast_param_by_value(y, self.scale)
return loc_local - scale_local * self.log(-self.log(y))
def _forward_log_jacobian(self, x):
x = self._check_value(x, 'value')
x = self.cast(x, self.parameter_type)
z = (x - self.loc) / self.scale
return -z - self.exp(-z) - self.log(self.scale)
x = self._check_value_dtype(x)
loc_local = self.cast_param_by_value(x, self.loc)
scale_local = self.cast_param_by_value(x, self.scale)
z = (x - loc_local) / scale_local
return -z - self.exp(-z) - self.log(scale_local)
def _inverse_log_jacobian(self, y):
y = self._check_value(y, 'value')
y = self.cast(y, self.parameter_type)
return self.log(self.scale / (-1. * y * self.log(y)))
y = self._check_value_dtype(y)
scale_local = self.cast_param_by_value(y, self.scale)
return self.log(scale_local / (-1. * y * self.log(y)))

View File

@ -53,23 +53,17 @@ class Invert(Bijector):
name = (name + bijector.name) if name == 'Invert' else name
super(Invert, self).__init__(is_constant_jacobian=bijector.is_constant_jacobian,
is_injective=bijector.is_injective,
dtype=bijector.dtype,
name=name,
dtype=bijector.dtype,
param=param)
self._bijector = bijector
if hasattr(self._bijector, 'event_shape'):
self._event_shape = self.bijector.event_shape
else:
self._event_shape = ()
self._batch_shape = self.bijector.batch_shape
self._is_scalar_batch = self.bijector.is_scalar_batch
@property
def bijector(self):
return self._bijector
@property
def event_shape(self):
return self._event_shape
def inverse(self, y):
return self.bijector("forward", y)

View File

@ -14,8 +14,7 @@
# ============================================================================
"""Power Bijector"""
from mindspore.ops import operations as P
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from ..distribution._utils.utils import check_greater_equal_zero
from ..distribution._utils.custom_ops import exp_generic, expm1_generic, log_generic, log1p_generic
from .bijector import Bijector
@ -37,7 +36,7 @@ class PowerTransform(Bijector):
ValueError: When the power is less than 0 or is not known statically.
Args:
power (int or float): The scale factor. Default: 0.
power (float, list, numpy.ndarray, Tensor): The scale factor. Default: 0.
name (str): The name of the bijector. Default: 'PowerTransform'.
Examples:
@ -64,10 +63,11 @@ class PowerTransform(Bijector):
power=0,
name='PowerTransform'):
param = dict(locals())
param['param_dict'] = {'power': power}
super(PowerTransform, self).__init__(name=name, param=param)
validator.check_value_type('power', power, [int, float], self.name)
validator.check_number("power", power, 0, Rel.GE, self.name)
self._power = power
self._power = self._add_parameter(power, 'power')
check_greater_equal_zero(self._power, 'Power')
self.pow = P.Pow()
self.dtypeop = P.DType()
self.cast = P.Cast()
@ -81,13 +81,15 @@ class PowerTransform(Bijector):
return self._power
def extend_repr(self):
return f'power = {self.power}'
if self.is_scalar_batch:
str_info = f'power = {self.power}'
else:
str_info = f'batch_shape = {self.batch_shape}'
return str_info
def shape_mapping(self, shape):
return shape
def _forward(self, x):
x = self._check_value(x, 'value')
x = self._check_value_dtype(x)
power_local = self.cast_param_by_value(x, self.power)
if power_local == 0:
forward_v = self.exp(x)
@ -96,7 +98,7 @@ class PowerTransform(Bijector):
return forward_v
def _inverse(self, y):
y = self._check_value(y, 'value')
y = self._check_value_dtype(y)
power_local = self.cast_param_by_value(y, self.power)
if power_local == 0:
inverse_v = self.log(y)
@ -116,7 +118,7 @@ class PowerTransform(Bijector):
f'(x) = e^\frac{\log(xc + 1)}{c} * \frac{1}{xc + 1}
\log(f'(x)) = (\frac{1}{c} - 1) * \log(xc + 1)
"""
x = self._check_value(x, 'value')
x = self._check_value_dtype(x)
power_local = self.cast_param_by_value(x, self.power)
if power_local == 0:
forward_log_j = x
@ -136,7 +138,7 @@ class PowerTransform(Bijector):
f'(x) = \frac{e^c\log(y)}{y}
\log(f'(x)) = \log(\frac{e^c\log(y)}{y}) = (c-1) * \log(y)
"""
y = self._check_value(y, 'value')
y = self._check_value_dtype(y)
power_local = self.cast_param_by_value(y, self.power)
inverse_log_j = (power_local - 1) * self.log(y)
return inverse_log_j

View File

@ -14,8 +14,6 @@
# ============================================================================
"""Scalar Affine Bijector"""
from mindspore.ops import operations as P
from mindspore._checkparam import Validator as validator
from ..distribution._utils.utils import cast_to_tensor
from ..distribution._utils.custom_ops import log_generic
from .bijector import Bijector
@ -30,10 +28,14 @@ class ScalarAffine(Bijector):
where a is the scale factor and b is the shift factor.
Args:
scale (float): The scale factor. Default: 1.0.
shift (float): The shift factor. Default: 0.0.
scale (float, list, numpy.ndarray, Tensor): The scale factor. Default: 1.0.
shift (float, list, numpy.ndarray, Tensor): The shift factor. Default: 0.0.
name (str): The name of the bijector. Default: 'ScalarAffine'.
Note:
If `shift`, `scale` are passed in as numpy.ndarray or tensor, they have to have
the same dtype otherwise an error will be raised.
Examples:
>>> # To initialize a ScalarAffine bijector of scale 1 and shift 2.
>>> scalaraffine = nn.probability.bijector.ScalarAffine(1, 2)
@ -61,10 +63,7 @@ class ScalarAffine(Bijector):
Constructor of ScalarAffine Bijector.
"""
param = dict(locals())
validator.check_value_type(
'scale', scale, [int, float], type(self).__name__)
validator.check_value_type(
'shift', shift, [int, float], type(self).__name__)
param['param_dict'] = {'scale': scale, 'shift': shift}
super(ScalarAffine, self).__init__(
is_constant_jacobian=True,
is_injective=True,
@ -72,8 +71,8 @@ class ScalarAffine(Bijector):
dtype=None,
param=param)
self._scale = cast_to_tensor(scale)
self._shift = cast_to_tensor(shift)
self._scale = self._add_parameter(scale, 'scale')
self._shift = self._add_parameter(shift, 'shift')
self.abs = P.Abs()
self.oneslike = P.OnesLike()
@ -90,17 +89,18 @@ class ScalarAffine(Bijector):
return self._shift
def extend_repr(self):
return f'scale = {self.scale}, shift = {self.shift}'
def shape_mapping(self, shape):
return shape
if self.is_scalar_batch:
str_info = f'scale = {self.scale}, shift = {self.shift}'
else:
str_info = f'batch_shape = {self.batch_shape}'
return str_info
def _forward(self, x):
r"""
.. math::
f(x) = a * x + b
"""
x = self._check_value(x, 'value')
x = self._check_value_dtype(x)
scale_local = self.cast_param_by_value(x, self.scale)
shift_local = self.cast_param_by_value(x, self.shift)
forward_v = scale_local * x + shift_local * self.oneslike(x)
@ -111,7 +111,7 @@ class ScalarAffine(Bijector):
.. math::
f(y) = \frac{y - b}{a}
"""
y = self._check_value(y, 'value')
y = self._check_value_dtype(y)
scale_local = self.cast_param_by_value(y, self.scale)
shift_local = self.cast_param_by_value(y, self.shift)
inverse_v = (y - shift_local) / scale_local
@ -124,7 +124,7 @@ class ScalarAffine(Bijector):
f'(x) = a
\log(f'(x)) = \log(a)
"""
x = self._check_value(x, 'value')
x = self._check_value_dtype(x)
scale_local = self.cast_param_by_value(x, self.scale)
forward_log_j = self.log(self.abs(scale_local))
return forward_log_j
@ -136,7 +136,7 @@ class ScalarAffine(Bijector):
f'(x) = \frac{1.0}{a}
\log(f'(x)) = - \log(a)
"""
y = self._check_value(y, 'value')
y = self._check_value_dtype(y)
scale_local = self.cast_param_by_value(y, self.scale)
inverse_log_j = -1. * self.log(self.abs(scale_local))
return inverse_log_j

View File

@ -16,8 +16,6 @@
import numpy as np
from mindspore.ops import operations as P
from mindspore.nn.layer.activation import LogSigmoid
from mindspore._checkparam import Validator as validator
from ..distribution._utils.utils import cast_to_tensor
from ..distribution._utils.custom_ops import exp_generic, expm1_generic, log_generic
from .bijector import Bijector
@ -32,7 +30,7 @@ class Softplus(Bijector):
where k is the sharpness factor.
Args:
sharpness (float): The scale factor. Default: 1.0.
sharpness (float, list, numpy.ndarray, Tensor): The scale factor. Default: 1.0.
name (str): The name of the Bijector. Default: 'Softplus'.
Examples:
@ -61,10 +59,9 @@ class Softplus(Bijector):
Constructor of Softplus Bijector.
"""
param = dict(locals())
validator.check_value_type('sharpness', sharpness,
[int, float], type(self).__name__)
super(Softplus, self).__init__(name=name, param=param)
self._sharpness = cast_to_tensor(sharpness)
param['param_dict'] = {'sharpness': sharpness}
super(Softplus, self).__init__(name=name, dtype=None, param=param)
self._sharpness = self._add_parameter(sharpness, 'sharpness')
self.exp = exp_generic
self.log = log_generic
@ -118,13 +115,14 @@ class Softplus(Bijector):
return self._sharpness
def extend_repr(self):
return f'sharpness = {self.sharpness}'
def shape_mapping(self, shape):
return shape
if self.is_scalar_batch:
str_info = f'sharpness = {self.sharpness}'
else:
str_info = f'batch_shape = {self.batch_shape}'
return str_info
def _forward(self, x):
x = self._check_value(x, 'value')
x = self._check_value_dtype(x)
sharpness_local = self.cast_param_by_value(x, self.sharpness)
scaled_value = sharpness_local * x
forward_v = self.softplus(scaled_value) / sharpness_local
@ -136,7 +134,7 @@ class Softplus(Bijector):
f(x) = \frac{\log(1 + e^{kx}))}{k}
f^{-1}(y) = \frac{\log(e^{ky} - 1)}{k}
"""
y = self._check_value(y, 'value')
y = self._check_value_dtype(y)
sharpness_local = self.cast_param_by_value(y, self.sharpness)
scaled_value = sharpness_local * y
inverse_v = self.inverse_softplus(scaled_value) / sharpness_local
@ -149,7 +147,7 @@ class Softplus(Bijector):
f'(x) = \frac{e^{kx}}{ 1 + e^{kx}}
\log(f'(x)) = kx - \log(1 + e^{kx}) = kx - f(kx)
"""
x = self._check_value(x, 'value')
x = self._check_value_dtype(x)
sharpness_local = self.cast_param_by_value(x, self.sharpness)
scaled_value = sharpness_local * x
forward_log_j = self.log_sigmoid(scaled_value)
@ -162,7 +160,7 @@ class Softplus(Bijector):
f'(y) = \frac{e^{ky}}{e^{ky} - 1}
\log(f'(y)) = ky - \log(e^{ky} - 1) = ky - f(ky)
"""
y = self._check_value(y, 'value')
y = self._check_value_dtype(y)
sharpness_local = self.cast_param_by_value(y, self.sharpness)
scaled_value = sharpness_local * y
inverse_log_j = scaled_value - self.inverse_softplus(scaled_value)

View File

@ -229,6 +229,11 @@ def raise_not_implemented_util(func_name, obj, *args, **kwargs):
raise NotImplementedError(
f"{func_name} is not implemented for {obj} distribution.")
@constexpr
def raise_type_error(name, cur_type, required_type):
raise TypeError(
f"For {name} , the type should be or be subclass of {required_type}, but got {cur_type}")
@constexpr
def check_distribution_name(name, expected_name):
@ -304,7 +309,7 @@ def set_param_type(args, hint_type):
TypeError: if tensors in args are not the same dtype.
"""
int_type = mstype.int_type + mstype.uint_type
if hint_type in int_type:
if hint_type in int_type or hint_type is None:
hint_type = mstype.float32
common_dtype = None
for name, arg in args.items():

View File

@ -72,13 +72,12 @@ class Distribution(Cell):
if not(k == 'self' or k.startswith('_')):
self._parameters[k] = param[k]
# some attributes
if 'distribution' in self.parameters.keys():
self.parameter_type = self.parameters['distribution'].parameter_type
else:
# if not a transformed distribution, set the following attribute
if 'distribution' not in self.parameters.keys():
self.parameter_type = set_param_type(self.parameters['param_dict'], dtype)
self._broadcast_shape = self._calc_broadcast_shape()
self._is_scalar_batch = self._check_is_scalar_batch()
self._batch_shape = self._calc_batch_shape()
self._is_scalar_batch = self._check_is_scalar_batch()
self._broadcast_shape = self._batch_shape
# set the function to call according to the derived class's attributes
self._set_prob()
@ -128,6 +127,10 @@ class Distribution(Cell):
def is_scalar_batch(self):
return self._is_scalar_batch
@property
def batch_shape(self):
return self._batch_shape
@property
def broadcast_shape(self):
return self._broadcast_shape
@ -208,8 +211,6 @@ class Distribution(Cell):
"""
Check if the parameters used during initialization are scalars.
"""
if 'distribution' in self.parameters.keys():
return self.parameters['distribution'].is_scalar_batch
param_dict = self.parameters['param_dict']
for value in param_dict.values():
if value is None:
@ -218,12 +219,10 @@ class Distribution(Cell):
return False
return True
def _calc_broadcast_shape(self):
def _calc_batch_shape(self):
"""
Calculate the broadcast shape of the parameters used during initialization.
"""
if 'distribution' in self.parameters.keys():
return self.parameters['distribution'].broadcast_shape
param_dict = self.parameters['param_dict']
broadcast_shape_tensor = None
for value in param_dict.values():
@ -362,14 +361,14 @@ class Distribution(Cell):
"""
return self._get_dist_args(*args, **kwargs)
def _get_dist_type(self, *args, **kwargs):
return raise_not_implemented_util('get_dist_type', self.name, *args, **kwargs)
def _get_dist_type(self):
return raise_not_implemented_util('get_dist_type', self.name)
def get_dist_type(self, *args, **kwargs):
def get_dist_type(self):
"""
Return the type of the distribution.
"""
return self._get_dist_type(*args, **kwargs)
return self._get_dist_type()
def _raise_not_implemented_error(self, func_name):
name = self.name
@ -751,5 +750,5 @@ class Distribution(Cell):
if name == 'get_dist_args':
return self._get_dist_args(*args, **kwargs)
if name == 'get_dist_type':
return self._get_dist_type(*args, **kwargs)
return self._get_dist_type()
return raise_not_implemented_util(name, self.name, *args, **kwargs)

View File

@ -103,17 +103,12 @@ class Gumbel(TransformedDistribution):
"""
valid_dtype = mstype.float_type
Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__)
gumbel_cdf = msb.GumbelCDF(loc, scale, dtype)
gumbel_cdf = msb.GumbelCDF(loc, scale)
super(Gumbel, self).__init__(
distribution=msd.Uniform(0.0, 1.0, dtype=dtype),
bijector=msb.Invert(gumbel_cdf),
seed=seed, name=name)
self.parameter_type = gumbel_cdf.parameter_type
self._broadcast_shape = gumbel_cdf.event_shape
if self._broadcast_shape != ():
self._is_scalar_batch = False
# overwrite default_parameters and parameter_names
self._reset_parameters()
self._loc = self._add_parameter(loc, 'loc')
@ -202,6 +197,7 @@ class Gumbel(TransformedDistribution):
where z = \frac{x - loc}{scale}
"""
value = self._check_value(value, 'value')
value = self.cast(value, self.dtype)
z = (value - self.loc) / self.scale
return -(z + self.exp(-z)) - self.log(self.scale)
@ -210,6 +206,8 @@ class Gumbel(TransformedDistribution):
.. math::
cdf_pdf(X) = \exp(-\exp(-\frac{x - loc}{scale})
"""
value = self._check_value(value, 'value')
value = self.cast(value, self.dtype)
return self._gumbel_bijector("forward", value)
def _cross_entropy(self, dist, loc_b, scale_b):
@ -251,12 +249,14 @@ class Gumbel(TransformedDistribution):
self.expm1((loc_b - self.loc) / scale_b + self.lgamma(self.scale / scale_b + 1.))
def _sample(self, shape=()):
shape = self.checktuple(shape, 'shape')
origin_shape = shape + self._broadcast_shape
if origin_shape == ():
sample_shape = (1,)
else:
sample_shape = origin_shape
org_sample = self.distribution("sample", sample_shape)
org_sample = self.cast(org_sample, self.dtype)
value = self.bijector("forward", org_sample)
if origin_shape == ():
value = self.squeeze(value)

View File

@ -137,6 +137,11 @@ class LogNormal(msd.TransformedDistribution):
bijector=msb.Exp(),
seed=seed, name=name)
# overwrite default_parameters and parameter_names
self._reset_parameters()
self._loc = self._add_parameter(loc, 'loc')
self._scale = self._add_parameter(scale, 'scale')
self.log_2pi = np.log(2 * np.pi)
#ops needed for the class
@ -154,12 +159,12 @@ class LogNormal(msd.TransformedDistribution):
@property
def loc(self):
"""Distribution parameter for the pre-transformed mean."""
return self.distribution("mean")
return self._loc
@property
def scale(self):
"""Distribution parameter for the pre-transformed standard deviation."""
return self.distribution("sd")
return self._scale
def _get_dist_type(self):
return "LogNormal"
@ -168,18 +173,18 @@ class LogNormal(msd.TransformedDistribution):
if loc is not None:
self.checktensor(loc, 'loc')
else:
loc = self.distribution("mean")
loc = self.loc
if scale is not None:
self.checktensor(scale, 'scale')
else:
scale = self.distribution("sd")
scale = self.scale
return loc, scale
def extend_repr(self):
if self.is_scalar_batch:
s = f'loc = {self._mean_value}, scale = {self._sd_value}'
s = f'loc = {self.loc}, scale = {self.scale}'
else:
s = f'batch_shape = {self._broadcast_shape}'
s = f'batch_shape = {self.broadcast_shape}'
return s
def _mean(self, loc=None, scale=None):

View File

@ -16,6 +16,7 @@
import numpy as np
from mindspore._checkparam import Validator as validator
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
import mindspore.nn as nn
from .distribution import Distribution
from ._utils.utils import raise_not_impl_error
@ -30,7 +31,7 @@ class TransformedDistribution(Distribution):
Args:
bijector (Bijector): The transformation to perform.
distribution (Distribution): The original distribution.
distribution (Distribution): The original distribution. Must has dtype of mindspore.float_type.
seed (int): The seed is used in sampling. The global seed is used if it is None. Default:None.
If this seed is given when a TransformedDistribution object is initialised, the object's sampling function
will use this seed; elsewise, the underlying distribution's seed will be used.
@ -40,6 +41,12 @@ class TransformedDistribution(Distribution):
The arguments used to initialize the original distribution cannot be None.
For example, mynormal = nn.Normal(dtype=dtyple.float32) cannot be used to initialized a
TransformedDistribution since `mean` and `sd` are not specified.
`batch_shape` is the batch_shape of the original distribution.
`broadcast_shape` is the broadcast shape between the original distribution and bijector.
`is_scalar_batch` is only true if both the original distribution and the bijector are scalar batches.
`default_parameters`, `parameter_names` and `parameter_type` are set to be consistent with the original
distribution. Derived class can overwrite `default_parameters` and `parameter_names` by calling
`reset_parameters` followed by `add_parameter`.
Examples:
>>> # To initialize a transformed distribution, e.g. a lognormal distribution,
@ -75,28 +82,34 @@ class TransformedDistribution(Distribution):
[nn.probability.bijector.Bijector], type(self).__name__)
validator.check_value_type('distribution', distribution,
[Distribution], type(self).__name__)
validator.check_type_name("dtype", distribution.dtype, mstype.float_type, type(self).__name__)
super(TransformedDistribution, self).__init__(seed, distribution.dtype, name, param)
self._bijector = bijector
self._distribution = distribution
self._is_linear_transformation = bijector.is_constant_jacobian
self.default_parameters = distribution.default_parameters
self.parameter_names = distribution.parameter_names
# set attributes
self._is_linear_transformation = self.bijector.is_constant_jacobian
self._dtype = self.distribution.dtype
self._is_scalar_batch = self.distribution.is_scalar_batch and self.bijector.is_scalar_batch
self._batch_shape = self.distribution.batch_shape
self.default_parameters = self.distribution.default_parameters
self.parameter_names = self.distribution.parameter_names
# by default, set the parameter_type to be the distribution's parameter_type
self.parameter_type = self.distribution.parameter_type
self.exp = exp_generic
self.log = log_generic
self.isnan = P.IsNan()
self.cast_base = P.Cast()
self.equal_base = P.Equal()
self.select_base = P.Select()
self.fill = P.Fill()
self.fill_base = P.Fill()
# broadcast bijector batch_shape and distribution batch_shape
self._broadcast_shape = self._broadcast_bijector_dist()
# check if batch shape of the distribution and event shape is broadcastable
if hasattr(self.bijector, 'event_shape'):
event_shape_tensor = self.fill(self.dtype, self.bijector.event_shape, 0.0)
broadcast_shape_tensor = self.fill(self.dtype, self.broadcast_shape, 0.0)
self._batch_event = (event_shape_tensor + broadcast_shape_tensor).shape
else:
self._batch_event = self.broadcast_shape
@property
def bijector(self):
@ -108,12 +121,22 @@ class TransformedDistribution(Distribution):
@property
def dtype(self):
return self.distribution.dtype
return self._dtype
@property
def is_linear_transformation(self):
return self._is_linear_transformation
def _broadcast_bijector_dist(self):
"""
check if the batch shape of base distribution and the bijector is broadcastable.
"""
if self.batch_shape is None or self.bijector.batch_shape is None:
return None
bijector_shape_tensor = self.fill_base(self.dtype, self.bijector.batch_shape, 0.0)
dist_shape_tensor = self.fill_base(self.dtype, self.batch_shape, 0.0)
return (bijector_shape_tensor + dist_shape_tensor).shape
def _cdf(self, value, *args, **kwargs):
r"""
.. math::

View File

@ -0,0 +1,191 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test cases for exp"""
import numpy as np
import pytest
import mindspore.nn as nn
import mindspore.nn.probability.bijector as msb
from mindspore import Tensor
from mindspore import dtype
class MyBijector(msb.Bijector):
"""
Customized bijector class with dtype not specified.
"""
def __init__(self, param1, param2):
param = dict(locals())
param['param_dict'] = {'param1': param1, 'param2': param2}
super(MyBijector, self).__init__(name='MyBijector', dtype=None, param=param)
self._param1 = self._add_parameter(param1, 'param1')
self._param2 = self._add_parameter(param2, 'param2')
@property
def param1(self):
return self._param1
@property
def param2(self):
return self._param2
def _forward(self, value):
value = self._check_value_dtype(value)
param1_local = self.cast_param_by_value(value, self.param1)
param2_local = self.cast_param_by_value(value, self.param2)
return value * param1_local + param2_local
class MySecondBijector(msb.Bijector):
"""
Customized bijector class with dtype specified.
"""
def __init__(self, param1, param2):
param = dict(locals())
param['param_dict'] = {'param1': param1, 'param2': param2}
super(MySecondBijector, self).__init__(name='MySecondBijector', dtype=dtype.float32, param=param)
self._param1 = self._add_parameter(param1, 'param1')
self._param2 = self._add_parameter(param2, 'param2')
@property
def param1(self):
return self._param1
@property
def param2(self):
return self._param2
def _forward(self, value):
value = self._check_value_dtype(value)
param1_local = self.cast_param_by_value(value, self.param1)
param2_local = self.cast_param_by_value(value, self.param2)
return value * param1_local + param2_local
def test_arguments_same_type():
"""
Test bijector initializations.
"""
param1_1 = np.array(1.0).astype(np.float16)
param2_1 = np.array(2.0).astype(np.float32)
with pytest.raises(TypeError):
MyBijector(param1_1, param2_1)
param1_2 = Tensor(1.0, dtype=dtype.float16)
param2_2 = Tensor(2.0, dtype=dtype.float32)
with pytest.raises(TypeError):
MyBijector(param1_2, param2_2)
with pytest.raises(TypeError):
MyBijector(True, param2_2)
with pytest.raises(TypeError):
MyBijector(None, param2_2)
param1_3 = Tensor(1.0, dtype=dtype.float32)
param2_3 = Tensor(2.0, dtype=dtype.float32)
bijector = MyBijector(param1_3, param2_3)
assert isinstance(bijector, msb.Bijector)
param1_4 = np.array([1.0, 2.0]).astype(np.float16)
param2_4 = np.array([1.0, 2.0]).astype(np.float16)
bijector = MyBijector(param1_4, param2_4)
assert isinstance(bijector, msb.Bijector)
bijector = MyBijector(1.0, 2.0)
assert isinstance(bijector, msb.Bijector)
def test_arguments_with_dtype_specified():
"""
Customized bijector class with dtype not specified.
"""
param1_1 = np.array(1.0).astype(np.float16)
param2_1 = np.array(2.0).astype(np.float16)
with pytest.raises(TypeError):
MySecondBijector(param1_1, param2_1)
param1_2 = Tensor(1.0, dtype=dtype.float16)
param2_2 = Tensor(2.0, dtype=dtype.float32)
with pytest.raises(TypeError):
MySecondBijector(param1_2, param2_2)
with pytest.raises(TypeError):
MySecondBijector(True, param2_2)
with pytest.raises(TypeError):
MySecondBijector(None, param2_2)
param1_3 = Tensor(1.0, dtype=dtype.float32)
param2_3 = Tensor(2.0, dtype=dtype.float32)
bijector = MyBijector(param1_3, param2_3)
assert isinstance(bijector, msb.Bijector)
param1_4 = np.array(2.0).astype(np.float32)
param2_4 = np.array(1.0).astype(np.float32)
bijector = MyBijector(param1_4, param2_4)
assert isinstance(bijector, msb.Bijector)
class Net1(nn.Cell):
"""
Test input value when bijector's dtype is not specified.
"""
def __init__(self):
super(Net1, self).__init__()
self.bijector = MyBijector(np.array(1.0).astype(np.float32), np.array(2.0).astype(np.float32))
def construct(self, value):
return self.bijector.forward(value)
class Net2(nn.Cell):
"""
Test input value when bijector's dtype is specified.
"""
def __init__(self):
super(Net2, self).__init__()
self.bijector = MySecondBijector(np.array(1.0).astype(np.float32), np.array(2.0).astype(np.float32))
def construct(self, value):
return self.bijector.forward(value)
def test_input_value():
"""
Test validity of input value.
"""
net = Net1()
value = None
with pytest.raises(TypeError):
ans = net(value)
value = 1.0
with pytest.raises(TypeError):
ans = net(value)
value = Tensor(1.0, dtype=dtype.int32)
with pytest.raises(TypeError):
ans = net(value)
value = Tensor(1.0, dtype=dtype.float32)
ans = net(value)
assert ans.dtype == dtype.float32
value = Tensor(1.0, dtype=dtype.float16)
ans = net(value)
assert ans.dtype == dtype.float16
def test_input_value2():
"""
Test validity of input value.
"""
net = Net2()
value = None
with pytest.raises(TypeError):
ans = net(value)
value = 1.0
with pytest.raises(TypeError):
ans = net(value)
value = Tensor(1.0, dtype=dtype.int32)
with pytest.raises(TypeError):
ans = net(value)
value = Tensor(1.0, dtype=dtype.float16)
with pytest.raises(TypeError):
ans = net(value)
value = Tensor(1.0, dtype=dtype.float32)
ans = net(value)
assert ans.dtype == dtype.float32

View File

@ -1,3 +1,7 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0