diff --git a/mindspore/nn/probability/bijector/bijector.py b/mindspore/nn/probability/bijector/bijector.py
index 35abd2d5210..bc395f573c3 100644
--- a/mindspore/nn/probability/bijector/bijector.py
+++ b/mindspore/nn/probability/bijector/bijector.py
@@ -16,8 +16,10 @@
 from mindspore import context
 from mindspore.nn.cell import Cell
 from mindspore.ops import operations as P
+from mindspore.common import dtype as mstype
+from mindspore.common.tensor import Tensor
 from mindspore._checkparam import Validator as validator
-from ..distribution._utils.utils import CheckTensor, cast_to_tensor
+from ..distribution._utils.utils import CheckTensor, cast_to_tensor, raise_type_error
 from ..distribution import Distribution
 from ..distribution import TransformedDistribution
@@ -32,6 +34,17 @@ class Bijector(Cell):
         name (str): The name of the Bijector. Default: None.
         dtype (mindspore.dtype): The type of the distributions that the Bijector can operate on. Default: None.
         param (dict): The parameters used to initialize the Bijector. Default: None.
+
+    Note:
+        `dtype` of a bijector represents the type of the distributions that the bijector could operate on.
+        When `dtype` is None, there is no enforcement on the dtype of the input value except that the input
+        value has to be a float type. During initialization, when `dtype` is None, there is no enforcement on
+        the dtype of the parameters either, but all parameters must share the same float type, otherwise a
+        TypeError will be raised. Specifically, the parameter type follows the dtype of the input value, i.e.
+        the parameters of the bijector will be cast to the same type as the input value when `dtype` is None.
+        When `dtype` is specified, the parameters and the input value are forced to have the same dtype as
+        `dtype`; if the dtype of the parameters or of the input value differs from `dtype`, a TypeError will
+        be raised. Only subtypes of mindspore.float_type can be used to specify the bijector's `dtype`.
     """

     def __init__(self,
@@ -48,6 +61,8 @@ class Bijector(Cell):
         validator.check_value_type(
             'is_constant_jacobian', is_constant_jacobian, [bool], name)
         validator.check_value_type('is_injective', is_injective, [bool], name)
+        if dtype is not None:
+            validator.check_type_name("dtype", dtype, mstype.float_type, type(self).__name__)
         self._name = name
         self._dtype = dtype
         self._parameters = {}
@@ -57,6 +72,12 @@ class Bijector(Cell):
                 continue
             if not(k == 'self' or k.startswith('_')):
                 self._parameters[k] = param[k]
+
+        # if no bijector is used as an argument during initialization
+        if 'bijector' not in param.keys():
+            self._batch_shape = self._calc_batch_shape()
+            self._is_scalar_batch = self._check_is_scalar_batch()
+
         self._is_constant_jacobian = is_constant_jacobian
         self._is_injective = is_injective
@@ -68,6 +89,8 @@ class Bijector(Cell):
         self.dtype_base = P.DType()
         self.shape_base = P.Shape()
         self.fill_base = P.Fill()
+        self.sametypeshape_base = P.SameTypeShape()
+        self.issubclass_base = P.IsSubClass()

     @property
     def name(self):
@@ -89,6 +112,38 @@ class Bijector(Cell):
     def is_injective(self):
         return self._is_injective

+    @property
+    def batch_shape(self):
+        return self._batch_shape
+
+    @property
+    def is_scalar_batch(self):
+        return self._is_scalar_batch
+
+    def _check_value_dtype(self, value):
+        """
+        First, check whether the input value is a Tensor. Then, if `self.dtype` is None, check
+        whether the input tensor is, or can be directly cast to, a float tensor.
+        If `self.dtype` is not None, check that the input tensor's dtype is `self.dtype`.
+ """ + self.checktensor(value, 'input value of bijector') + value_type = self.dtype_base(value) + if self.dtype is None: + if self.issubclass_base(value_type, mstype.float_): + return value + return raise_type_error('input value of bijector', value_type, mstype.float_) + dtype_tensor = self.fill_base(self.dtype, self.shape_base(value), 0.0) + self.sametypeshape_base(value, dtype_tensor) + return value + + def _shape_mapping(self, shape): + shape_tensor = self.fill_base(self.parameter_type, shape, 0.0) + dist_shape_tensor = self.fill_base(self.parameter_type, self.batch_shape, 0.0) + return (shape_tensor + dist_shape_tensor).shape + + def shape_mapping(self, shape): + return self._shape_mapping(shape) + def _add_parameter(self, value, name): """ Cast `value` to a tensor and add it to `self.default_parameters`. @@ -98,26 +153,51 @@ class Bijector(Cell): if not hasattr(self, 'default_parameters'): self.default_parameters = [] self.parameter_names = [] + self.common_dtype = None # cast value to a tensor if it is not None - value_t = None if value is None else cast_to_tensor(value, self.parameter_type) - self.default_parameters += [value_t,] + if isinstance(value, bool) or value is None: + raise TypeError(f"{name} cannot be type {type(value)}") + value_t = Tensor(value) + # if the bijector's dtype is not specified + if self.dtype is None: + if self.common_dtype is None: + self.common_dtype = value_t.dtype + elif value_t.dtype != self.common_dtype: + raise TypeError(f"{name} should have the same dtype as other arguments.") + # check if the dtype of the input_parameter agrees with the bijector's dtype + elif value_t.dtype != self.dtype: + raise TypeError(f"{name} should have the same dtype as the bijector's dtype.") + self.default_parameters += [value,] self.parameter_names += [name,] return value_t - def _calc_event_shape(self): + def _calc_batch_shape(self): """ - Calculate event_shape based on parameters. + Calculate batch_shape based on parameters. """ - broadcast_shape = None - for param in self.default_parameters: - if broadcast_shape is None: - broadcast_shape = self.shape_base(param) - broadcast_shape_tensor = self.fill_base(self.parameter_type, broadcast_shape, 0.0) + param_dict = self.parameters['param_dict'] + broadcast_shape_tensor = None + for value in param_dict.values(): + if value is None: + return None + if broadcast_shape_tensor is None: + broadcast_shape_tensor = cast_to_tensor(value) else: - broadcast_shape = self.shape_base(param + broadcast_shape_tensor) - broadcast_shape_tensor = self.fill_base(self.parameter_type, broadcast_shape, 0.0) - return broadcast_shape + value = cast_to_tensor(value) + broadcast_shape_tensor = (value + broadcast_shape_tensor) + return broadcast_shape_tensor.shape + def _check_is_scalar_batch(self): + """ + Check if the parameters used during initialization are scalars. + """ + param_dict = self.parameters['param_dict'] + for value in param_dict.values(): + if value is None: + continue + if not isinstance(value, (int, float)): + return False + return True def _check_value(self, value, name): """ @@ -127,32 +207,35 @@ class Bijector(Cell): return value def cast_param_by_value(self, value, para): + """ + Cast the parameter(s) of the bijector to be the same type of input_value. + """ local = self.cast_base(para, self.dtype_base(value)) return local - def forward(self, *args, **kwargs): + def forward(self, value, *args, **kwargs): """ Forward transformation: transform the input value to another distribution. 
""" - return self._forward(*args, **kwargs) + return self._forward(value, *args, **kwargs) - def inverse(self, *args, **kwargs): + def inverse(self, value, *args, **kwargs): """ Inverse transformation: transform the input value back to the original distribution. """ - return self._inverse(*args, **kwargs) + return self._inverse(value, *args, **kwargs) - def forward_log_jacobian(self, *args, **kwargs): + def forward_log_jacobian(self, value, *args, **kwargs): """ Logarithm of the derivative of the forward transformation. """ - return self._forward_log_jacobian(*args, **kwargs) + return self._forward_log_jacobian(value, *args, **kwargs) - def inverse_log_jacobian(self, *args, **kwargs): + def inverse_log_jacobian(self, value, *args, **kwargs): """ Logarithm of the derivative of the inverse transformation. """ - return self._inverse_log_jacobian(*args, **kwargs) + return self._inverse_log_jacobian(value, *args, **kwargs) def __call__(self, *args, **kwargs): """ @@ -167,7 +250,7 @@ class Bijector(Cell): *args: args[0] shall be either a distribution or the name of a Bijector function. """ if isinstance(args[0], Distribution): - return TransformedDistribution(self, args[0], self.distribution.dtype) + return TransformedDistribution(self, args[0]) return super(Bijector, self).__call__(*args, **kwargs) def construct(self, name, *args, **kwargs): diff --git a/mindspore/nn/probability/bijector/gumbel_cdf.py b/mindspore/nn/probability/bijector/gumbel_cdf.py index eef7affc6a0..a2f53048747 100644 --- a/mindspore/nn/probability/bijector/gumbel_cdf.py +++ b/mindspore/nn/probability/bijector/gumbel_cdf.py @@ -13,10 +13,8 @@ # limitations under the License. # ============================================================================ """GumbelCDF Bijector""" -from mindspore.common import dtype as mstype -from mindspore._checkparam import Validator from mindspore.ops import operations as P -from ..distribution._utils.utils import check_greater_zero, set_param_type +from ..distribution._utils.utils import check_greater_zero from ..distribution._utils.custom_ops import exp_generic, log_generic from .bijector import Bijector @@ -30,12 +28,11 @@ class GumbelCDF(Bijector): Y = \exp(-\exp(\frac{-(X - loc)}{scale})) Note: - For `reverse` and `reverse_log_jacobian`, input should be in range of (0, 1). + For `inverse` and `inverse_log_jacobian`, input should be in range of (0, 1). Args: - loc (int, float, list, numpy.ndarray, Tensor): The location. Default: 0.. - scale (int, float, list, numpy.ndarray, Tensor): The scale. Default: 1.0. - dtype (mindspore.dtype): Type of the distribution which the bijector operates on. Default: float32. + loc (float, list, numpy.ndarray, Tensor): The location. Default: 0.. + scale (float, list, numpy.ndarray, Tensor): The scale. Default: 1.0. name (str): The name of the Bijector. Default: 'Gumbel_CDF'. Examples: @@ -61,22 +58,18 @@ class GumbelCDF(Bijector): def __init__(self, loc=0.0, scale=1.0, - dtype=mstype.float32, name='GumbelCDF'): """ Constructor of GumbelCDF Bijector. 
""" param = dict(locals()) - valid_dtype = mstype.float_type + mstype.int_type + mstype.uint_type - Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__) - parameter_type = set_param_type({'loc': loc, "scale": scale}, dtype) - super(GumbelCDF, self).__init__(name=name, dtype=dtype, param=param) + param['param_dict'] = {'loc': loc, 'scale': scale} + super(GumbelCDF, self).__init__(name=name, param=param) - self._parameter_type = parameter_type self._loc = self._add_parameter(loc, 'loc') self._scale = self._add_parameter(scale, 'scale') check_greater_zero(self._scale, "scale") - self._event_shape = self._calc_event_shape() + self.cast = P.Cast() self.exp = exp_generic @@ -91,38 +84,34 @@ class GumbelCDF(Bijector): def scale(self): return self._scale - @property - def event_shape(self): - return self._event_shape - - @property - def parameter_type(self): - return self._parameter_type - def extend_repr(self): - return f'loc = {self.loc}, scale = {self.scale}' - - def shape_mapping(self, shape): - return shape + if self.is_scalar_batch: + str_info = f'loc = {self.loc}, scale = {self.scale}' + else: + str_info = f'batch_shape = {self.batch_shape}' + return str_info def _forward(self, x): - x = self._check_value(x, 'value') - x = self.cast(x, self.parameter_type) - z = (x - self.loc) / self.scale + x = self._check_value_dtype(x) + loc_local = self.cast_param_by_value(x, self.loc) + scale_local = self.cast_param_by_value(x, self.scale) + z = (x - loc_local) / scale_local return self.exp(-self.exp(-z)) def _inverse(self, y): - y = self._check_value(y, 'value') - y = self.cast(y, self.parameter_type) - return self.loc - self.scale * self.log(-self.log(y)) + y = self._check_value_dtype(y) + loc_local = self.cast_param_by_value(y, self.loc) + scale_local = self.cast_param_by_value(y, self.scale) + return loc_local - scale_local * self.log(-self.log(y)) def _forward_log_jacobian(self, x): - x = self._check_value(x, 'value') - x = self.cast(x, self.parameter_type) - z = (x - self.loc) / self.scale - return -z - self.exp(-z) - self.log(self.scale) + x = self._check_value_dtype(x) + loc_local = self.cast_param_by_value(x, self.loc) + scale_local = self.cast_param_by_value(x, self.scale) + z = (x - loc_local) / scale_local + return -z - self.exp(-z) - self.log(scale_local) def _inverse_log_jacobian(self, y): - y = self._check_value(y, 'value') - y = self.cast(y, self.parameter_type) - return self.log(self.scale / (-1. * y * self.log(y))) + y = self._check_value_dtype(y) + scale_local = self.cast_param_by_value(y, self.scale) + return self.log(scale_local / (-1. 
* y * self.log(y))) diff --git a/mindspore/nn/probability/bijector/invert.py b/mindspore/nn/probability/bijector/invert.py index 17f8dbc27a0..3cb63a9b5e9 100644 --- a/mindspore/nn/probability/bijector/invert.py +++ b/mindspore/nn/probability/bijector/invert.py @@ -53,23 +53,17 @@ class Invert(Bijector): name = (name + bijector.name) if name == 'Invert' else name super(Invert, self).__init__(is_constant_jacobian=bijector.is_constant_jacobian, is_injective=bijector.is_injective, - dtype=bijector.dtype, name=name, + dtype=bijector.dtype, param=param) self._bijector = bijector - if hasattr(self._bijector, 'event_shape'): - self._event_shape = self.bijector.event_shape - else: - self._event_shape = () + self._batch_shape = self.bijector.batch_shape + self._is_scalar_batch = self.bijector.is_scalar_batch @property def bijector(self): return self._bijector - @property - def event_shape(self): - return self._event_shape - def inverse(self, y): return self.bijector("forward", y) diff --git a/mindspore/nn/probability/bijector/power_transform.py b/mindspore/nn/probability/bijector/power_transform.py index 58c2b5bc54f..9a2ba0e2643 100644 --- a/mindspore/nn/probability/bijector/power_transform.py +++ b/mindspore/nn/probability/bijector/power_transform.py @@ -14,8 +14,7 @@ # ============================================================================ """Power Bijector""" from mindspore.ops import operations as P -from mindspore._checkparam import Validator as validator -from mindspore._checkparam import Rel +from ..distribution._utils.utils import check_greater_equal_zero from ..distribution._utils.custom_ops import exp_generic, expm1_generic, log_generic, log1p_generic from .bijector import Bijector @@ -37,7 +36,7 @@ class PowerTransform(Bijector): ValueError: When the power is less than 0 or is not known statically. Args: - power (int or float): The scale factor. Default: 0. + power (float, list, numpy.ndarray, Tensor): The scale factor. Default: 0. name (str): The name of the bijector. Default: 'PowerTransform'. 
Examples: @@ -64,10 +63,11 @@ class PowerTransform(Bijector): power=0, name='PowerTransform'): param = dict(locals()) + param['param_dict'] = {'power': power} super(PowerTransform, self).__init__(name=name, param=param) - validator.check_value_type('power', power, [int, float], self.name) - validator.check_number("power", power, 0, Rel.GE, self.name) - self._power = power + self._power = self._add_parameter(power, 'power') + check_greater_equal_zero(self._power, 'Power') + self.pow = P.Pow() self.dtypeop = P.DType() self.cast = P.Cast() @@ -81,13 +81,15 @@ class PowerTransform(Bijector): return self._power def extend_repr(self): - return f'power = {self.power}' + if self.is_scalar_batch: + str_info = f'power = {self.power}' + else: + str_info = f'batch_shape = {self.batch_shape}' + return str_info - def shape_mapping(self, shape): - return shape def _forward(self, x): - x = self._check_value(x, 'value') + x = self._check_value_dtype(x) power_local = self.cast_param_by_value(x, self.power) if power_local == 0: forward_v = self.exp(x) @@ -96,7 +98,7 @@ class PowerTransform(Bijector): return forward_v def _inverse(self, y): - y = self._check_value(y, 'value') + y = self._check_value_dtype(y) power_local = self.cast_param_by_value(y, self.power) if power_local == 0: inverse_v = self.log(y) @@ -116,7 +118,7 @@ class PowerTransform(Bijector): f'(x) = e^\frac{\log(xc + 1)}{c} * \frac{1}{xc + 1} \log(f'(x)) = (\frac{1}{c} - 1) * \log(xc + 1) """ - x = self._check_value(x, 'value') + x = self._check_value_dtype(x) power_local = self.cast_param_by_value(x, self.power) if power_local == 0: forward_log_j = x @@ -136,7 +138,7 @@ class PowerTransform(Bijector): f'(x) = \frac{e^c\log(y)}{y} \log(f'(x)) = \log(\frac{e^c\log(y)}{y}) = (c-1) * \log(y) """ - y = self._check_value(y, 'value') + y = self._check_value_dtype(y) power_local = self.cast_param_by_value(y, self.power) inverse_log_j = (power_local - 1) * self.log(y) return inverse_log_j diff --git a/mindspore/nn/probability/bijector/scalar_affine.py b/mindspore/nn/probability/bijector/scalar_affine.py index f6ab8b8a806..94f7865daa3 100644 --- a/mindspore/nn/probability/bijector/scalar_affine.py +++ b/mindspore/nn/probability/bijector/scalar_affine.py @@ -14,8 +14,6 @@ # ============================================================================ """Scalar Affine Bijector""" from mindspore.ops import operations as P -from mindspore._checkparam import Validator as validator -from ..distribution._utils.utils import cast_to_tensor from ..distribution._utils.custom_ops import log_generic from .bijector import Bijector @@ -30,10 +28,14 @@ class ScalarAffine(Bijector): where a is the scale factor and b is the shift factor. Args: - scale (float): The scale factor. Default: 1.0. - shift (float): The shift factor. Default: 0.0. + scale (float, list, numpy.ndarray, Tensor): The scale factor. Default: 1.0. + shift (float, list, numpy.ndarray, Tensor): The shift factor. Default: 0.0. name (str): The name of the bijector. Default: 'ScalarAffine'. + Note: + If `shift`, `scale` are passed in as numpy.ndarray or tensor, they have to have + the same dtype otherwise an error will be raised. + Examples: >>> # To initialize a ScalarAffine bijector of scale 1 and shift 2. >>> scalaraffine = nn.probability.bijector.ScalarAffine(1, 2) @@ -61,10 +63,7 @@ class ScalarAffine(Bijector): Constructor of ScalarAffine Bijector. 
""" param = dict(locals()) - validator.check_value_type( - 'scale', scale, [int, float], type(self).__name__) - validator.check_value_type( - 'shift', shift, [int, float], type(self).__name__) + param['param_dict'] = {'scale': scale, 'shift': shift} super(ScalarAffine, self).__init__( is_constant_jacobian=True, is_injective=True, @@ -72,8 +71,8 @@ class ScalarAffine(Bijector): dtype=None, param=param) - self._scale = cast_to_tensor(scale) - self._shift = cast_to_tensor(shift) + self._scale = self._add_parameter(scale, 'scale') + self._shift = self._add_parameter(shift, 'shift') self.abs = P.Abs() self.oneslike = P.OnesLike() @@ -90,17 +89,18 @@ class ScalarAffine(Bijector): return self._shift def extend_repr(self): - return f'scale = {self.scale}, shift = {self.shift}' - - def shape_mapping(self, shape): - return shape + if self.is_scalar_batch: + str_info = f'scale = {self.scale}, shift = {self.shift}' + else: + str_info = f'batch_shape = {self.batch_shape}' + return str_info def _forward(self, x): r""" .. math:: f(x) = a * x + b """ - x = self._check_value(x, 'value') + x = self._check_value_dtype(x) scale_local = self.cast_param_by_value(x, self.scale) shift_local = self.cast_param_by_value(x, self.shift) forward_v = scale_local * x + shift_local * self.oneslike(x) @@ -111,7 +111,7 @@ class ScalarAffine(Bijector): .. math:: f(y) = \frac{y - b}{a} """ - y = self._check_value(y, 'value') + y = self._check_value_dtype(y) scale_local = self.cast_param_by_value(y, self.scale) shift_local = self.cast_param_by_value(y, self.shift) inverse_v = (y - shift_local) / scale_local @@ -124,7 +124,7 @@ class ScalarAffine(Bijector): f'(x) = a \log(f'(x)) = \log(a) """ - x = self._check_value(x, 'value') + x = self._check_value_dtype(x) scale_local = self.cast_param_by_value(x, self.scale) forward_log_j = self.log(self.abs(scale_local)) return forward_log_j @@ -136,7 +136,7 @@ class ScalarAffine(Bijector): f'(x) = \frac{1.0}{a} \log(f'(x)) = - \log(a) """ - y = self._check_value(y, 'value') + y = self._check_value_dtype(y) scale_local = self.cast_param_by_value(y, self.scale) inverse_log_j = -1. * self.log(self.abs(scale_local)) return inverse_log_j diff --git a/mindspore/nn/probability/bijector/softplus.py b/mindspore/nn/probability/bijector/softplus.py index 93184ddde36..17aa498acae 100644 --- a/mindspore/nn/probability/bijector/softplus.py +++ b/mindspore/nn/probability/bijector/softplus.py @@ -16,8 +16,6 @@ import numpy as np from mindspore.ops import operations as P from mindspore.nn.layer.activation import LogSigmoid -from mindspore._checkparam import Validator as validator -from ..distribution._utils.utils import cast_to_tensor from ..distribution._utils.custom_ops import exp_generic, expm1_generic, log_generic from .bijector import Bijector @@ -32,7 +30,7 @@ class Softplus(Bijector): where k is the sharpness factor. Args: - sharpness (float): The scale factor. Default: 1.0. + sharpness (float, list, numpy.ndarray, Tensor): The scale factor. Default: 1.0. name (str): The name of the Bijector. Default: 'Softplus'. Examples: @@ -61,10 +59,9 @@ class Softplus(Bijector): Constructor of Softplus Bijector. 
""" param = dict(locals()) - validator.check_value_type('sharpness', sharpness, - [int, float], type(self).__name__) - super(Softplus, self).__init__(name=name, param=param) - self._sharpness = cast_to_tensor(sharpness) + param['param_dict'] = {'sharpness': sharpness} + super(Softplus, self).__init__(name=name, dtype=None, param=param) + self._sharpness = self._add_parameter(sharpness, 'sharpness') self.exp = exp_generic self.log = log_generic @@ -118,13 +115,14 @@ class Softplus(Bijector): return self._sharpness def extend_repr(self): - return f'sharpness = {self.sharpness}' - - def shape_mapping(self, shape): - return shape + if self.is_scalar_batch: + str_info = f'sharpness = {self.sharpness}' + else: + str_info = f'batch_shape = {self.batch_shape}' + return str_info def _forward(self, x): - x = self._check_value(x, 'value') + x = self._check_value_dtype(x) sharpness_local = self.cast_param_by_value(x, self.sharpness) scaled_value = sharpness_local * x forward_v = self.softplus(scaled_value) / sharpness_local @@ -136,7 +134,7 @@ class Softplus(Bijector): f(x) = \frac{\log(1 + e^{kx}))}{k} f^{-1}(y) = \frac{\log(e^{ky} - 1)}{k} """ - y = self._check_value(y, 'value') + y = self._check_value_dtype(y) sharpness_local = self.cast_param_by_value(y, self.sharpness) scaled_value = sharpness_local * y inverse_v = self.inverse_softplus(scaled_value) / sharpness_local @@ -149,7 +147,7 @@ class Softplus(Bijector): f'(x) = \frac{e^{kx}}{ 1 + e^{kx}} \log(f'(x)) = kx - \log(1 + e^{kx}) = kx - f(kx) """ - x = self._check_value(x, 'value') + x = self._check_value_dtype(x) sharpness_local = self.cast_param_by_value(x, self.sharpness) scaled_value = sharpness_local * x forward_log_j = self.log_sigmoid(scaled_value) @@ -162,7 +160,7 @@ class Softplus(Bijector): f'(y) = \frac{e^{ky}}{e^{ky} - 1} \log(f'(y)) = ky - \log(e^{ky} - 1) = ky - f(ky) """ - y = self._check_value(y, 'value') + y = self._check_value_dtype(y) sharpness_local = self.cast_param_by_value(y, self.sharpness) scaled_value = sharpness_local * y inverse_log_j = scaled_value - self.inverse_softplus(scaled_value) diff --git a/mindspore/nn/probability/distribution/_utils/utils.py b/mindspore/nn/probability/distribution/_utils/utils.py index c344e9c06f4..8e269fbd56e 100644 --- a/mindspore/nn/probability/distribution/_utils/utils.py +++ b/mindspore/nn/probability/distribution/_utils/utils.py @@ -229,6 +229,11 @@ def raise_not_implemented_util(func_name, obj, *args, **kwargs): raise NotImplementedError( f"{func_name} is not implemented for {obj} distribution.") +@constexpr +def raise_type_error(name, cur_type, required_type): + raise TypeError( + f"For {name} , the type should be or be subclass of {required_type}, but got {cur_type}") + @constexpr def check_distribution_name(name, expected_name): @@ -304,7 +309,7 @@ def set_param_type(args, hint_type): TypeError: if tensors in args are not the same dtype. 
""" int_type = mstype.int_type + mstype.uint_type - if hint_type in int_type: + if hint_type in int_type or hint_type is None: hint_type = mstype.float32 common_dtype = None for name, arg in args.items(): diff --git a/mindspore/nn/probability/distribution/distribution.py b/mindspore/nn/probability/distribution/distribution.py index d218d885374..8dfda69af8f 100644 --- a/mindspore/nn/probability/distribution/distribution.py +++ b/mindspore/nn/probability/distribution/distribution.py @@ -72,13 +72,12 @@ class Distribution(Cell): if not(k == 'self' or k.startswith('_')): self._parameters[k] = param[k] - # some attributes - if 'distribution' in self.parameters.keys(): - self.parameter_type = self.parameters['distribution'].parameter_type - else: + # if not a transformed distribution, set the following attribute + if 'distribution' not in self.parameters.keys(): self.parameter_type = set_param_type(self.parameters['param_dict'], dtype) - self._broadcast_shape = self._calc_broadcast_shape() - self._is_scalar_batch = self._check_is_scalar_batch() + self._batch_shape = self._calc_batch_shape() + self._is_scalar_batch = self._check_is_scalar_batch() + self._broadcast_shape = self._batch_shape # set the function to call according to the derived class's attributes self._set_prob() @@ -128,6 +127,10 @@ class Distribution(Cell): def is_scalar_batch(self): return self._is_scalar_batch + @property + def batch_shape(self): + return self._batch_shape + @property def broadcast_shape(self): return self._broadcast_shape @@ -208,8 +211,6 @@ class Distribution(Cell): """ Check if the parameters used during initialization are scalars. """ - if 'distribution' in self.parameters.keys(): - return self.parameters['distribution'].is_scalar_batch param_dict = self.parameters['param_dict'] for value in param_dict.values(): if value is None: @@ -218,12 +219,10 @@ class Distribution(Cell): return False return True - def _calc_broadcast_shape(self): + def _calc_batch_shape(self): """ Calculate the broadcast shape of the parameters used during initialization. """ - if 'distribution' in self.parameters.keys(): - return self.parameters['distribution'].broadcast_shape param_dict = self.parameters['param_dict'] broadcast_shape_tensor = None for value in param_dict.values(): @@ -362,14 +361,14 @@ class Distribution(Cell): """ return self._get_dist_args(*args, **kwargs) - def _get_dist_type(self, *args, **kwargs): - return raise_not_implemented_util('get_dist_type', self.name, *args, **kwargs) + def _get_dist_type(self): + return raise_not_implemented_util('get_dist_type', self.name) - def get_dist_type(self, *args, **kwargs): + def get_dist_type(self): """ Return the type of the distribution. 
""" - return self._get_dist_type(*args, **kwargs) + return self._get_dist_type() def _raise_not_implemented_error(self, func_name): name = self.name @@ -751,5 +750,5 @@ class Distribution(Cell): if name == 'get_dist_args': return self._get_dist_args(*args, **kwargs) if name == 'get_dist_type': - return self._get_dist_type(*args, **kwargs) + return self._get_dist_type() return raise_not_implemented_util(name, self.name, *args, **kwargs) diff --git a/mindspore/nn/probability/distribution/gumbel.py b/mindspore/nn/probability/distribution/gumbel.py index f598e29d55e..50d891ee82d 100644 --- a/mindspore/nn/probability/distribution/gumbel.py +++ b/mindspore/nn/probability/distribution/gumbel.py @@ -103,17 +103,12 @@ class Gumbel(TransformedDistribution): """ valid_dtype = mstype.float_type Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__) - gumbel_cdf = msb.GumbelCDF(loc, scale, dtype) + gumbel_cdf = msb.GumbelCDF(loc, scale) super(Gumbel, self).__init__( distribution=msd.Uniform(0.0, 1.0, dtype=dtype), bijector=msb.Invert(gumbel_cdf), seed=seed, name=name) - self.parameter_type = gumbel_cdf.parameter_type - self._broadcast_shape = gumbel_cdf.event_shape - if self._broadcast_shape != (): - self._is_scalar_batch = False - # overwrite default_parameters and parameter_names self._reset_parameters() self._loc = self._add_parameter(loc, 'loc') @@ -202,6 +197,7 @@ class Gumbel(TransformedDistribution): where z = \frac{x - loc}{scale} """ value = self._check_value(value, 'value') + value = self.cast(value, self.dtype) z = (value - self.loc) / self.scale return -(z + self.exp(-z)) - self.log(self.scale) @@ -210,6 +206,8 @@ class Gumbel(TransformedDistribution): .. math:: cdf_pdf(X) = \exp(-\exp(-\frac{x - loc}{scale}) """ + value = self._check_value(value, 'value') + value = self.cast(value, self.dtype) return self._gumbel_bijector("forward", value) def _cross_entropy(self, dist, loc_b, scale_b): @@ -251,12 +249,14 @@ class Gumbel(TransformedDistribution): self.expm1((loc_b - self.loc) / scale_b + self.lgamma(self.scale / scale_b + 1.)) def _sample(self, shape=()): + shape = self.checktuple(shape, 'shape') origin_shape = shape + self._broadcast_shape if origin_shape == (): sample_shape = (1,) else: sample_shape = origin_shape org_sample = self.distribution("sample", sample_shape) + org_sample = self.cast(org_sample, self.dtype) value = self.bijector("forward", org_sample) if origin_shape == (): value = self.squeeze(value) diff --git a/mindspore/nn/probability/distribution/log_normal.py b/mindspore/nn/probability/distribution/log_normal.py index 64bc160ff2e..c82e79f75c9 100644 --- a/mindspore/nn/probability/distribution/log_normal.py +++ b/mindspore/nn/probability/distribution/log_normal.py @@ -137,6 +137,11 @@ class LogNormal(msd.TransformedDistribution): bijector=msb.Exp(), seed=seed, name=name) + # overwrite default_parameters and parameter_names + self._reset_parameters() + self._loc = self._add_parameter(loc, 'loc') + self._scale = self._add_parameter(scale, 'scale') + self.log_2pi = np.log(2 * np.pi) #ops needed for the class @@ -154,12 +159,12 @@ class LogNormal(msd.TransformedDistribution): @property def loc(self): """Distribution parameter for the pre-transformed mean.""" - return self.distribution("mean") + return self._loc @property def scale(self): """Distribution parameter for the pre-transformed standard deviation.""" - return self.distribution("sd") + return self._scale def _get_dist_type(self): return "LogNormal" @@ -168,18 +173,18 @@ class 
LogNormal(msd.TransformedDistribution):
         if loc is not None:
             self.checktensor(loc, 'loc')
         else:
-            loc = self.distribution("mean")
+            loc = self.loc
         if scale is not None:
             self.checktensor(scale, 'scale')
         else:
-            scale = self.distribution("sd")
+            scale = self.scale
         return loc, scale

     def extend_repr(self):
         if self.is_scalar_batch:
-            s = f'loc = {self._mean_value}, scale = {self._sd_value}'
+            s = f'loc = {self.loc}, scale = {self.scale}'
         else:
-            s = f'batch_shape = {self._broadcast_shape}'
+            s = f'batch_shape = {self.broadcast_shape}'
         return s

     def _mean(self, loc=None, scale=None):
diff --git a/mindspore/nn/probability/distribution/transformed_distribution.py b/mindspore/nn/probability/distribution/transformed_distribution.py
index 1bcc77781df..927420291c9 100644
--- a/mindspore/nn/probability/distribution/transformed_distribution.py
+++ b/mindspore/nn/probability/distribution/transformed_distribution.py
@@ -16,6 +16,7 @@
 import numpy as np
 from mindspore._checkparam import Validator as validator
 from mindspore.ops import operations as P
+from mindspore.common import dtype as mstype
 import mindspore.nn as nn
 from .distribution import Distribution
 from ._utils.utils import raise_not_impl_error
@@ -30,7 +31,7 @@ class TransformedDistribution(Distribution):

     Args:
         bijector (Bijector): The transformation to perform.
-        distribution (Distribution): The original distribution.
+        distribution (Distribution): The original distribution. Its dtype must be a subtype of mindspore.float_type.
         seed (int): The seed is used in sampling. The global seed is used if it is None. Default:None.
             If this seed is given when a TransformedDistribution object is initialised, the object's sampling function
             will use this seed; elsewise, the underlying distribution's seed will be used.
@@ -40,6 +41,12 @@
         The arguments used to initialize the original distribution cannot be None. For example,
         mynormal = nn.Normal(dtype=dtyple.float32) cannot be used to initialized a TransformedDistribution since
         `mean` and `sd` are not specified.
+        `batch_shape` is the batch_shape of the original distribution.
+        `broadcast_shape` is the broadcast shape between the original distribution and the bijector.
+        `is_scalar_batch` is only True if both the original distribution and the bijector are scalar batches.
+        `default_parameters`, `parameter_names` and `parameter_type` are set to be consistent with the original
+        distribution. Derived classes can overwrite `default_parameters` and `parameter_names` by calling
+        `_reset_parameters` followed by `_add_parameter`.

     Examples:
         >>> # To initialize a transformed distribution, e.g. 
a lognormal distribution, @@ -75,28 +82,34 @@ class TransformedDistribution(Distribution): [nn.probability.bijector.Bijector], type(self).__name__) validator.check_value_type('distribution', distribution, [Distribution], type(self).__name__) + validator.check_type_name("dtype", distribution.dtype, mstype.float_type, type(self).__name__) super(TransformedDistribution, self).__init__(seed, distribution.dtype, name, param) self._bijector = bijector self._distribution = distribution - self._is_linear_transformation = bijector.is_constant_jacobian - self.default_parameters = distribution.default_parameters - self.parameter_names = distribution.parameter_names + + # set attributes + self._is_linear_transformation = self.bijector.is_constant_jacobian + self._dtype = self.distribution.dtype + self._is_scalar_batch = self.distribution.is_scalar_batch and self.bijector.is_scalar_batch + self._batch_shape = self.distribution.batch_shape + + self.default_parameters = self.distribution.default_parameters + self.parameter_names = self.distribution.parameter_names + # by default, set the parameter_type to be the distribution's parameter_type + self.parameter_type = self.distribution.parameter_type self.exp = exp_generic self.log = log_generic self.isnan = P.IsNan() + self.cast_base = P.Cast() self.equal_base = P.Equal() self.select_base = P.Select() - self.fill = P.Fill() + self.fill_base = P.Fill() + + # broadcast bijector batch_shape and distribution batch_shape + self._broadcast_shape = self._broadcast_bijector_dist() - # check if batch shape of the distribution and event shape is broadcastable - if hasattr(self.bijector, 'event_shape'): - event_shape_tensor = self.fill(self.dtype, self.bijector.event_shape, 0.0) - broadcast_shape_tensor = self.fill(self.dtype, self.broadcast_shape, 0.0) - self._batch_event = (event_shape_tensor + broadcast_shape_tensor).shape - else: - self._batch_event = self.broadcast_shape @property def bijector(self): @@ -108,12 +121,22 @@ class TransformedDistribution(Distribution): @property def dtype(self): - return self.distribution.dtype + return self._dtype @property def is_linear_transformation(self): return self._is_linear_transformation + def _broadcast_bijector_dist(self): + """ + check if the batch shape of base distribution and the bijector is broadcastable. + """ + if self.batch_shape is None or self.bijector.batch_shape is None: + return None + bijector_shape_tensor = self.fill_base(self.dtype, self.bijector.batch_shape, 0.0) + dist_shape_tensor = self.fill_base(self.dtype, self.batch_shape, 0.0) + return (bijector_shape_tensor + dist_shape_tensor).shape + def _cdf(self, value, *args, **kwargs): r""" .. math:: diff --git a/tests/ut/python/nn/probability/bijector/test_bijector.py b/tests/ut/python/nn/probability/bijector/test_bijector.py new file mode 100644 index 00000000000..f3ae926c5bf --- /dev/null +++ b/tests/ut/python/nn/probability/bijector/test_bijector.py @@ -0,0 +1,191 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""test cases for exp""" +import numpy as np +import pytest + +import mindspore.nn as nn +import mindspore.nn.probability.bijector as msb +from mindspore import Tensor +from mindspore import dtype + +class MyBijector(msb.Bijector): + """ + Customized bijector class with dtype not specified. + """ + def __init__(self, param1, param2): + param = dict(locals()) + param['param_dict'] = {'param1': param1, 'param2': param2} + super(MyBijector, self).__init__(name='MyBijector', dtype=None, param=param) + + self._param1 = self._add_parameter(param1, 'param1') + self._param2 = self._add_parameter(param2, 'param2') + + @property + def param1(self): + return self._param1 + + @property + def param2(self): + return self._param2 + + def _forward(self, value): + value = self._check_value_dtype(value) + param1_local = self.cast_param_by_value(value, self.param1) + param2_local = self.cast_param_by_value(value, self.param2) + return value * param1_local + param2_local + +class MySecondBijector(msb.Bijector): + """ + Customized bijector class with dtype specified. + """ + def __init__(self, param1, param2): + param = dict(locals()) + param['param_dict'] = {'param1': param1, 'param2': param2} + super(MySecondBijector, self).__init__(name='MySecondBijector', dtype=dtype.float32, param=param) + + self._param1 = self._add_parameter(param1, 'param1') + self._param2 = self._add_parameter(param2, 'param2') + + @property + def param1(self): + return self._param1 + + @property + def param2(self): + return self._param2 + + def _forward(self, value): + value = self._check_value_dtype(value) + param1_local = self.cast_param_by_value(value, self.param1) + param2_local = self.cast_param_by_value(value, self.param2) + return value * param1_local + param2_local + +def test_arguments_same_type(): + """ + Test bijector initializations. + """ + param1_1 = np.array(1.0).astype(np.float16) + param2_1 = np.array(2.0).astype(np.float32) + with pytest.raises(TypeError): + MyBijector(param1_1, param2_1) + param1_2 = Tensor(1.0, dtype=dtype.float16) + param2_2 = Tensor(2.0, dtype=dtype.float32) + with pytest.raises(TypeError): + MyBijector(param1_2, param2_2) + with pytest.raises(TypeError): + MyBijector(True, param2_2) + with pytest.raises(TypeError): + MyBijector(None, param2_2) + param1_3 = Tensor(1.0, dtype=dtype.float32) + param2_3 = Tensor(2.0, dtype=dtype.float32) + bijector = MyBijector(param1_3, param2_3) + assert isinstance(bijector, msb.Bijector) + param1_4 = np.array([1.0, 2.0]).astype(np.float16) + param2_4 = np.array([1.0, 2.0]).astype(np.float16) + bijector = MyBijector(param1_4, param2_4) + assert isinstance(bijector, msb.Bijector) + bijector = MyBijector(1.0, 2.0) + assert isinstance(bijector, msb.Bijector) + +def test_arguments_with_dtype_specified(): + """ + Customized bijector class with dtype not specified. 
+ """ + param1_1 = np.array(1.0).astype(np.float16) + param2_1 = np.array(2.0).astype(np.float16) + with pytest.raises(TypeError): + MySecondBijector(param1_1, param2_1) + param1_2 = Tensor(1.0, dtype=dtype.float16) + param2_2 = Tensor(2.0, dtype=dtype.float32) + with pytest.raises(TypeError): + MySecondBijector(param1_2, param2_2) + with pytest.raises(TypeError): + MySecondBijector(True, param2_2) + with pytest.raises(TypeError): + MySecondBijector(None, param2_2) + param1_3 = Tensor(1.0, dtype=dtype.float32) + param2_3 = Tensor(2.0, dtype=dtype.float32) + bijector = MyBijector(param1_3, param2_3) + assert isinstance(bijector, msb.Bijector) + param1_4 = np.array(2.0).astype(np.float32) + param2_4 = np.array(1.0).astype(np.float32) + bijector = MyBijector(param1_4, param2_4) + assert isinstance(bijector, msb.Bijector) + +class Net1(nn.Cell): + """ + Test input value when bijector's dtype is not specified. + """ + def __init__(self): + super(Net1, self).__init__() + self.bijector = MyBijector(np.array(1.0).astype(np.float32), np.array(2.0).astype(np.float32)) + + def construct(self, value): + return self.bijector.forward(value) + +class Net2(nn.Cell): + """ + Test input value when bijector's dtype is specified. + """ + def __init__(self): + super(Net2, self).__init__() + self.bijector = MySecondBijector(np.array(1.0).astype(np.float32), np.array(2.0).astype(np.float32)) + + def construct(self, value): + return self.bijector.forward(value) + +def test_input_value(): + """ + Test validity of input value. + """ + net = Net1() + value = None + with pytest.raises(TypeError): + ans = net(value) + value = 1.0 + with pytest.raises(TypeError): + ans = net(value) + value = Tensor(1.0, dtype=dtype.int32) + with pytest.raises(TypeError): + ans = net(value) + value = Tensor(1.0, dtype=dtype.float32) + ans = net(value) + assert ans.dtype == dtype.float32 + value = Tensor(1.0, dtype=dtype.float16) + ans = net(value) + assert ans.dtype == dtype.float16 + +def test_input_value2(): + """ + Test validity of input value. + """ + net = Net2() + value = None + with pytest.raises(TypeError): + ans = net(value) + value = 1.0 + with pytest.raises(TypeError): + ans = net(value) + value = Tensor(1.0, dtype=dtype.int32) + with pytest.raises(TypeError): + ans = net(value) + value = Tensor(1.0, dtype=dtype.float16) + with pytest.raises(TypeError): + ans = net(value) + value = Tensor(1.0, dtype=dtype.float32) + ans = net(value) + assert ans.dtype == dtype.float32 + \ No newline at end of file diff --git a/tests/ut/python/nn/probability/distribution/test_distribution.py b/tests/ut/python/nn/probability/distribution/test_distribution.py index bea14a4d65e..8542343c1a5 100644 --- a/tests/ut/python/nn/probability/distribution/test_distribution.py +++ b/tests/ut/python/nn/probability/distribution/test_distribution.py @@ -1,3 +1,7 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0
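For reference, a minimal usage sketch (not part of the patch) of the dtype behaviour described in the new Bijector note: when a bijector is built with `dtype=None`, its parameters follow the dtype of the input value. The snippet mirrors the pattern of the unit tests above; the wrapper cell name `AffineNet` is introduced here for illustration only.

# Illustrative sketch only, assuming the API as modified in this patch.
import numpy as np
import mindspore.nn as nn
import mindspore.nn.probability.bijector as msb
from mindspore import Tensor, dtype

class AffineNet(nn.Cell):
    """Hypothetical wrapper cell around ScalarAffine (which passes dtype=None to Bijector)."""
    def __init__(self):
        super(AffineNet, self).__init__()
        # scale/shift are plain floats, so no computation dtype is fixed at construction time.
        self.bijector = msb.ScalarAffine(scale=2.0, shift=1.0)

    def construct(self, value):
        return self.bijector.forward(value)

net = AffineNet()
# The parameters are cast to the input dtype inside _forward, so the result follows the input.
out16 = net(Tensor(np.array([1.0, 2.0]), dtype=dtype.float16))   # float16 output
out32 = net(Tensor(np.array([1.0, 2.0]), dtype=dtype.float32))   # float32 output
# A non-float input is rejected by _check_value_dtype:
# net(Tensor(np.array([1, 2]), dtype=dtype.int32))  # raises TypeError

When a bijector is constructed with an explicit `dtype` instead (as in `MySecondBijector` in the tests above), a float16 input would raise a TypeError rather than being accepted.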