forked from mindspore-Ecosystem/mindspore
redesigned bijector classes, changed bacth_shape calculation in transformed distribution and upgraded dtype logic of bijector class
This commit is contained in:
parent
05643cb20c
commit
fb0263f869
|
@ -16,8 +16,10 @@
|
|||
from mindspore import context
|
||||
from mindspore.nn.cell import Cell
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from ..distribution._utils.utils import CheckTensor, cast_to_tensor
|
||||
from ..distribution._utils.utils import CheckTensor, cast_to_tensor, raise_type_error
|
||||
from ..distribution import Distribution
|
||||
from ..distribution import TransformedDistribution
|
||||
|
||||
|
@ -32,6 +34,17 @@ class Bijector(Cell):
|
|||
name (str): The name of the Bijector. Default: None.
|
||||
dtype (mindspore.dtype): The type of the distributions that the Bijector can operate on. Default: None.
|
||||
param (dict): The parameters used to initialize the Bijector. Default: None.
|
||||
|
||||
Note:
|
||||
`dtype` of bijector represents the type of the distributions that the bijector could operate on.
|
||||
When `dtype` is None, there is no enforcement on the type of input value except that the input value
|
||||
has to be float type. During initilization, when `dtype` is None, there is no enforcement on the dtype
|
||||
of the parameters. All parameters should have the same float type, otherwise a TypeError will be raised.
|
||||
Specifically, the parameter type will follow the dtype of the input value, i.e. parameters of the bijector
|
||||
will be casted into the same type as input value when `dtype`is None.
|
||||
When `dtype` is specified, it is forcing the parameters and input value to be the same dtype as `dtype`.
|
||||
When the type of parameters or the type of the input value is not the same as `dtype`, a TypeError will be
|
||||
raised. Only subtype of mindspore.float_type can be used to specify bijector's `dtype`.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
|
@ -48,6 +61,8 @@ class Bijector(Cell):
|
|||
validator.check_value_type(
|
||||
'is_constant_jacobian', is_constant_jacobian, [bool], name)
|
||||
validator.check_value_type('is_injective', is_injective, [bool], name)
|
||||
if dtype is not None:
|
||||
validator.check_type_name("dtype", dtype, mstype.float_type, type(self).__name__)
|
||||
self._name = name
|
||||
self._dtype = dtype
|
||||
self._parameters = {}
|
||||
|
@ -57,6 +72,12 @@ class Bijector(Cell):
|
|||
continue
|
||||
if not(k == 'self' or k.startswith('_')):
|
||||
self._parameters[k] = param[k]
|
||||
|
||||
# if no bijector is used as an argument during initilization
|
||||
if 'bijector' not in param.keys():
|
||||
self._batch_shape = self._calc_batch_shape()
|
||||
self._is_scalar_batch = self._check_is_scalar_batch()
|
||||
|
||||
self._is_constant_jacobian = is_constant_jacobian
|
||||
self._is_injective = is_injective
|
||||
|
||||
|
@ -68,6 +89,8 @@ class Bijector(Cell):
|
|||
self.dtype_base = P.DType()
|
||||
self.shape_base = P.Shape()
|
||||
self.fill_base = P.Fill()
|
||||
self.sametypeshape_base = P.SameTypeShape()
|
||||
self.issubclass_base = P.IsSubClass()
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
|
@ -89,6 +112,38 @@ class Bijector(Cell):
|
|||
def is_injective(self):
|
||||
return self._is_injective
|
||||
|
||||
@property
|
||||
def batch_shape(self):
|
||||
return self._batch_shape
|
||||
|
||||
@property
|
||||
def is_scalar_batch(self):
|
||||
return self._is_scalar_batch
|
||||
|
||||
def _check_value_dtype(self, value):
|
||||
"""
|
||||
Firstly check if the input value is Tensor. Then, if `self.dtype` is None, check
|
||||
if the input tensor is or can be directly cast into a float tensor.
|
||||
If `self.dtype` is not None, check if the input tensor's dtype is `self.dtype`.
|
||||
"""
|
||||
self.checktensor(value, 'input value of bijector')
|
||||
value_type = self.dtype_base(value)
|
||||
if self.dtype is None:
|
||||
if self.issubclass_base(value_type, mstype.float_):
|
||||
return value
|
||||
return raise_type_error('input value of bijector', value_type, mstype.float_)
|
||||
dtype_tensor = self.fill_base(self.dtype, self.shape_base(value), 0.0)
|
||||
self.sametypeshape_base(value, dtype_tensor)
|
||||
return value
|
||||
|
||||
def _shape_mapping(self, shape):
|
||||
shape_tensor = self.fill_base(self.parameter_type, shape, 0.0)
|
||||
dist_shape_tensor = self.fill_base(self.parameter_type, self.batch_shape, 0.0)
|
||||
return (shape_tensor + dist_shape_tensor).shape
|
||||
|
||||
def shape_mapping(self, shape):
|
||||
return self._shape_mapping(shape)
|
||||
|
||||
def _add_parameter(self, value, name):
|
||||
"""
|
||||
Cast `value` to a tensor and add it to `self.default_parameters`.
|
||||
|
@ -98,26 +153,51 @@ class Bijector(Cell):
|
|||
if not hasattr(self, 'default_parameters'):
|
||||
self.default_parameters = []
|
||||
self.parameter_names = []
|
||||
self.common_dtype = None
|
||||
# cast value to a tensor if it is not None
|
||||
value_t = None if value is None else cast_to_tensor(value, self.parameter_type)
|
||||
self.default_parameters += [value_t,]
|
||||
if isinstance(value, bool) or value is None:
|
||||
raise TypeError(f"{name} cannot be type {type(value)}")
|
||||
value_t = Tensor(value)
|
||||
# if the bijector's dtype is not specified
|
||||
if self.dtype is None:
|
||||
if self.common_dtype is None:
|
||||
self.common_dtype = value_t.dtype
|
||||
elif value_t.dtype != self.common_dtype:
|
||||
raise TypeError(f"{name} should have the same dtype as other arguments.")
|
||||
# check if the dtype of the input_parameter agrees with the bijector's dtype
|
||||
elif value_t.dtype != self.dtype:
|
||||
raise TypeError(f"{name} should have the same dtype as the bijector's dtype.")
|
||||
self.default_parameters += [value,]
|
||||
self.parameter_names += [name,]
|
||||
return value_t
|
||||
|
||||
def _calc_event_shape(self):
|
||||
def _calc_batch_shape(self):
|
||||
"""
|
||||
Calculate event_shape based on parameters.
|
||||
Calculate batch_shape based on parameters.
|
||||
"""
|
||||
broadcast_shape = None
|
||||
for param in self.default_parameters:
|
||||
if broadcast_shape is None:
|
||||
broadcast_shape = self.shape_base(param)
|
||||
broadcast_shape_tensor = self.fill_base(self.parameter_type, broadcast_shape, 0.0)
|
||||
param_dict = self.parameters['param_dict']
|
||||
broadcast_shape_tensor = None
|
||||
for value in param_dict.values():
|
||||
if value is None:
|
||||
return None
|
||||
if broadcast_shape_tensor is None:
|
||||
broadcast_shape_tensor = cast_to_tensor(value)
|
||||
else:
|
||||
broadcast_shape = self.shape_base(param + broadcast_shape_tensor)
|
||||
broadcast_shape_tensor = self.fill_base(self.parameter_type, broadcast_shape, 0.0)
|
||||
return broadcast_shape
|
||||
value = cast_to_tensor(value)
|
||||
broadcast_shape_tensor = (value + broadcast_shape_tensor)
|
||||
return broadcast_shape_tensor.shape
|
||||
|
||||
def _check_is_scalar_batch(self):
|
||||
"""
|
||||
Check if the parameters used during initialization are scalars.
|
||||
"""
|
||||
param_dict = self.parameters['param_dict']
|
||||
for value in param_dict.values():
|
||||
if value is None:
|
||||
continue
|
||||
if not isinstance(value, (int, float)):
|
||||
return False
|
||||
return True
|
||||
|
||||
def _check_value(self, value, name):
|
||||
"""
|
||||
|
@ -127,32 +207,35 @@ class Bijector(Cell):
|
|||
return value
|
||||
|
||||
def cast_param_by_value(self, value, para):
|
||||
"""
|
||||
Cast the parameter(s) of the bijector to be the same type of input_value.
|
||||
"""
|
||||
local = self.cast_base(para, self.dtype_base(value))
|
||||
return local
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
def forward(self, value, *args, **kwargs):
|
||||
"""
|
||||
Forward transformation: transform the input value to another distribution.
|
||||
"""
|
||||
return self._forward(*args, **kwargs)
|
||||
return self._forward(value, *args, **kwargs)
|
||||
|
||||
def inverse(self, *args, **kwargs):
|
||||
def inverse(self, value, *args, **kwargs):
|
||||
"""
|
||||
Inverse transformation: transform the input value back to the original distribution.
|
||||
"""
|
||||
return self._inverse(*args, **kwargs)
|
||||
return self._inverse(value, *args, **kwargs)
|
||||
|
||||
def forward_log_jacobian(self, *args, **kwargs):
|
||||
def forward_log_jacobian(self, value, *args, **kwargs):
|
||||
"""
|
||||
Logarithm of the derivative of the forward transformation.
|
||||
"""
|
||||
return self._forward_log_jacobian(*args, **kwargs)
|
||||
return self._forward_log_jacobian(value, *args, **kwargs)
|
||||
|
||||
def inverse_log_jacobian(self, *args, **kwargs):
|
||||
def inverse_log_jacobian(self, value, *args, **kwargs):
|
||||
"""
|
||||
Logarithm of the derivative of the inverse transformation.
|
||||
"""
|
||||
return self._inverse_log_jacobian(*args, **kwargs)
|
||||
return self._inverse_log_jacobian(value, *args, **kwargs)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
"""
|
||||
|
@ -167,7 +250,7 @@ class Bijector(Cell):
|
|||
*args: args[0] shall be either a distribution or the name of a Bijector function.
|
||||
"""
|
||||
if isinstance(args[0], Distribution):
|
||||
return TransformedDistribution(self, args[0], self.distribution.dtype)
|
||||
return TransformedDistribution(self, args[0])
|
||||
return super(Bijector, self).__call__(*args, **kwargs)
|
||||
|
||||
def construct(self, name, *args, **kwargs):
|
||||
|
|
|
@ -13,10 +13,8 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""GumbelCDF Bijector"""
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore._checkparam import Validator
|
||||
from mindspore.ops import operations as P
|
||||
from ..distribution._utils.utils import check_greater_zero, set_param_type
|
||||
from ..distribution._utils.utils import check_greater_zero
|
||||
from ..distribution._utils.custom_ops import exp_generic, log_generic
|
||||
from .bijector import Bijector
|
||||
|
||||
|
@ -30,12 +28,11 @@ class GumbelCDF(Bijector):
|
|||
Y = \exp(-\exp(\frac{-(X - loc)}{scale}))
|
||||
|
||||
Note:
|
||||
For `reverse` and `reverse_log_jacobian`, input should be in range of (0, 1).
|
||||
For `inverse` and `inverse_log_jacobian`, input should be in range of (0, 1).
|
||||
|
||||
Args:
|
||||
loc (int, float, list, numpy.ndarray, Tensor): The location. Default: 0..
|
||||
scale (int, float, list, numpy.ndarray, Tensor): The scale. Default: 1.0.
|
||||
dtype (mindspore.dtype): Type of the distribution which the bijector operates on. Default: float32.
|
||||
loc (float, list, numpy.ndarray, Tensor): The location. Default: 0..
|
||||
scale (float, list, numpy.ndarray, Tensor): The scale. Default: 1.0.
|
||||
name (str): The name of the Bijector. Default: 'Gumbel_CDF'.
|
||||
|
||||
Examples:
|
||||
|
@ -61,22 +58,18 @@ class GumbelCDF(Bijector):
|
|||
def __init__(self,
|
||||
loc=0.0,
|
||||
scale=1.0,
|
||||
dtype=mstype.float32,
|
||||
name='GumbelCDF'):
|
||||
"""
|
||||
Constructor of GumbelCDF Bijector.
|
||||
"""
|
||||
param = dict(locals())
|
||||
valid_dtype = mstype.float_type + mstype.int_type + mstype.uint_type
|
||||
Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__)
|
||||
parameter_type = set_param_type({'loc': loc, "scale": scale}, dtype)
|
||||
super(GumbelCDF, self).__init__(name=name, dtype=dtype, param=param)
|
||||
param['param_dict'] = {'loc': loc, 'scale': scale}
|
||||
super(GumbelCDF, self).__init__(name=name, param=param)
|
||||
|
||||
self._parameter_type = parameter_type
|
||||
self._loc = self._add_parameter(loc, 'loc')
|
||||
self._scale = self._add_parameter(scale, 'scale')
|
||||
check_greater_zero(self._scale, "scale")
|
||||
self._event_shape = self._calc_event_shape()
|
||||
|
||||
|
||||
self.cast = P.Cast()
|
||||
self.exp = exp_generic
|
||||
|
@ -91,38 +84,34 @@ class GumbelCDF(Bijector):
|
|||
def scale(self):
|
||||
return self._scale
|
||||
|
||||
@property
|
||||
def event_shape(self):
|
||||
return self._event_shape
|
||||
|
||||
@property
|
||||
def parameter_type(self):
|
||||
return self._parameter_type
|
||||
|
||||
def extend_repr(self):
|
||||
return f'loc = {self.loc}, scale = {self.scale}'
|
||||
|
||||
def shape_mapping(self, shape):
|
||||
return shape
|
||||
if self.is_scalar_batch:
|
||||
str_info = f'loc = {self.loc}, scale = {self.scale}'
|
||||
else:
|
||||
str_info = f'batch_shape = {self.batch_shape}'
|
||||
return str_info
|
||||
|
||||
def _forward(self, x):
|
||||
x = self._check_value(x, 'value')
|
||||
x = self.cast(x, self.parameter_type)
|
||||
z = (x - self.loc) / self.scale
|
||||
x = self._check_value_dtype(x)
|
||||
loc_local = self.cast_param_by_value(x, self.loc)
|
||||
scale_local = self.cast_param_by_value(x, self.scale)
|
||||
z = (x - loc_local) / scale_local
|
||||
return self.exp(-self.exp(-z))
|
||||
|
||||
def _inverse(self, y):
|
||||
y = self._check_value(y, 'value')
|
||||
y = self.cast(y, self.parameter_type)
|
||||
return self.loc - self.scale * self.log(-self.log(y))
|
||||
y = self._check_value_dtype(y)
|
||||
loc_local = self.cast_param_by_value(y, self.loc)
|
||||
scale_local = self.cast_param_by_value(y, self.scale)
|
||||
return loc_local - scale_local * self.log(-self.log(y))
|
||||
|
||||
def _forward_log_jacobian(self, x):
|
||||
x = self._check_value(x, 'value')
|
||||
x = self.cast(x, self.parameter_type)
|
||||
z = (x - self.loc) / self.scale
|
||||
return -z - self.exp(-z) - self.log(self.scale)
|
||||
x = self._check_value_dtype(x)
|
||||
loc_local = self.cast_param_by_value(x, self.loc)
|
||||
scale_local = self.cast_param_by_value(x, self.scale)
|
||||
z = (x - loc_local) / scale_local
|
||||
return -z - self.exp(-z) - self.log(scale_local)
|
||||
|
||||
def _inverse_log_jacobian(self, y):
|
||||
y = self._check_value(y, 'value')
|
||||
y = self.cast(y, self.parameter_type)
|
||||
return self.log(self.scale / (-1. * y * self.log(y)))
|
||||
y = self._check_value_dtype(y)
|
||||
scale_local = self.cast_param_by_value(y, self.scale)
|
||||
return self.log(scale_local / (-1. * y * self.log(y)))
|
||||
|
|
|
@ -53,23 +53,17 @@ class Invert(Bijector):
|
|||
name = (name + bijector.name) if name == 'Invert' else name
|
||||
super(Invert, self).__init__(is_constant_jacobian=bijector.is_constant_jacobian,
|
||||
is_injective=bijector.is_injective,
|
||||
dtype=bijector.dtype,
|
||||
name=name,
|
||||
dtype=bijector.dtype,
|
||||
param=param)
|
||||
self._bijector = bijector
|
||||
if hasattr(self._bijector, 'event_shape'):
|
||||
self._event_shape = self.bijector.event_shape
|
||||
else:
|
||||
self._event_shape = ()
|
||||
self._batch_shape = self.bijector.batch_shape
|
||||
self._is_scalar_batch = self.bijector.is_scalar_batch
|
||||
|
||||
@property
|
||||
def bijector(self):
|
||||
return self._bijector
|
||||
|
||||
@property
|
||||
def event_shape(self):
|
||||
return self._event_shape
|
||||
|
||||
def inverse(self, y):
|
||||
return self.bijector("forward", y)
|
||||
|
||||
|
|
|
@ -14,8 +14,7 @@
|
|||
# ============================================================================
|
||||
"""Power Bijector"""
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from mindspore._checkparam import Rel
|
||||
from ..distribution._utils.utils import check_greater_equal_zero
|
||||
from ..distribution._utils.custom_ops import exp_generic, expm1_generic, log_generic, log1p_generic
|
||||
from .bijector import Bijector
|
||||
|
||||
|
@ -37,7 +36,7 @@ class PowerTransform(Bijector):
|
|||
ValueError: When the power is less than 0 or is not known statically.
|
||||
|
||||
Args:
|
||||
power (int or float): The scale factor. Default: 0.
|
||||
power (float, list, numpy.ndarray, Tensor): The scale factor. Default: 0.
|
||||
name (str): The name of the bijector. Default: 'PowerTransform'.
|
||||
|
||||
Examples:
|
||||
|
@ -64,10 +63,11 @@ class PowerTransform(Bijector):
|
|||
power=0,
|
||||
name='PowerTransform'):
|
||||
param = dict(locals())
|
||||
param['param_dict'] = {'power': power}
|
||||
super(PowerTransform, self).__init__(name=name, param=param)
|
||||
validator.check_value_type('power', power, [int, float], self.name)
|
||||
validator.check_number("power", power, 0, Rel.GE, self.name)
|
||||
self._power = power
|
||||
self._power = self._add_parameter(power, 'power')
|
||||
check_greater_equal_zero(self._power, 'Power')
|
||||
|
||||
self.pow = P.Pow()
|
||||
self.dtypeop = P.DType()
|
||||
self.cast = P.Cast()
|
||||
|
@ -81,13 +81,15 @@ class PowerTransform(Bijector):
|
|||
return self._power
|
||||
|
||||
def extend_repr(self):
|
||||
return f'power = {self.power}'
|
||||
if self.is_scalar_batch:
|
||||
str_info = f'power = {self.power}'
|
||||
else:
|
||||
str_info = f'batch_shape = {self.batch_shape}'
|
||||
return str_info
|
||||
|
||||
def shape_mapping(self, shape):
|
||||
return shape
|
||||
|
||||
def _forward(self, x):
|
||||
x = self._check_value(x, 'value')
|
||||
x = self._check_value_dtype(x)
|
||||
power_local = self.cast_param_by_value(x, self.power)
|
||||
if power_local == 0:
|
||||
forward_v = self.exp(x)
|
||||
|
@ -96,7 +98,7 @@ class PowerTransform(Bijector):
|
|||
return forward_v
|
||||
|
||||
def _inverse(self, y):
|
||||
y = self._check_value(y, 'value')
|
||||
y = self._check_value_dtype(y)
|
||||
power_local = self.cast_param_by_value(y, self.power)
|
||||
if power_local == 0:
|
||||
inverse_v = self.log(y)
|
||||
|
@ -116,7 +118,7 @@ class PowerTransform(Bijector):
|
|||
f'(x) = e^\frac{\log(xc + 1)}{c} * \frac{1}{xc + 1}
|
||||
\log(f'(x)) = (\frac{1}{c} - 1) * \log(xc + 1)
|
||||
"""
|
||||
x = self._check_value(x, 'value')
|
||||
x = self._check_value_dtype(x)
|
||||
power_local = self.cast_param_by_value(x, self.power)
|
||||
if power_local == 0:
|
||||
forward_log_j = x
|
||||
|
@ -136,7 +138,7 @@ class PowerTransform(Bijector):
|
|||
f'(x) = \frac{e^c\log(y)}{y}
|
||||
\log(f'(x)) = \log(\frac{e^c\log(y)}{y}) = (c-1) * \log(y)
|
||||
"""
|
||||
y = self._check_value(y, 'value')
|
||||
y = self._check_value_dtype(y)
|
||||
power_local = self.cast_param_by_value(y, self.power)
|
||||
inverse_log_j = (power_local - 1) * self.log(y)
|
||||
return inverse_log_j
|
||||
|
|
|
@ -14,8 +14,6 @@
|
|||
# ============================================================================
|
||||
"""Scalar Affine Bijector"""
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from ..distribution._utils.utils import cast_to_tensor
|
||||
from ..distribution._utils.custom_ops import log_generic
|
||||
from .bijector import Bijector
|
||||
|
||||
|
@ -30,10 +28,14 @@ class ScalarAffine(Bijector):
|
|||
where a is the scale factor and b is the shift factor.
|
||||
|
||||
Args:
|
||||
scale (float): The scale factor. Default: 1.0.
|
||||
shift (float): The shift factor. Default: 0.0.
|
||||
scale (float, list, numpy.ndarray, Tensor): The scale factor. Default: 1.0.
|
||||
shift (float, list, numpy.ndarray, Tensor): The shift factor. Default: 0.0.
|
||||
name (str): The name of the bijector. Default: 'ScalarAffine'.
|
||||
|
||||
Note:
|
||||
If `shift`, `scale` are passed in as numpy.ndarray or tensor, they have to have
|
||||
the same dtype otherwise an error will be raised.
|
||||
|
||||
Examples:
|
||||
>>> # To initialize a ScalarAffine bijector of scale 1 and shift 2.
|
||||
>>> scalaraffine = nn.probability.bijector.ScalarAffine(1, 2)
|
||||
|
@ -61,10 +63,7 @@ class ScalarAffine(Bijector):
|
|||
Constructor of ScalarAffine Bijector.
|
||||
"""
|
||||
param = dict(locals())
|
||||
validator.check_value_type(
|
||||
'scale', scale, [int, float], type(self).__name__)
|
||||
validator.check_value_type(
|
||||
'shift', shift, [int, float], type(self).__name__)
|
||||
param['param_dict'] = {'scale': scale, 'shift': shift}
|
||||
super(ScalarAffine, self).__init__(
|
||||
is_constant_jacobian=True,
|
||||
is_injective=True,
|
||||
|
@ -72,8 +71,8 @@ class ScalarAffine(Bijector):
|
|||
dtype=None,
|
||||
param=param)
|
||||
|
||||
self._scale = cast_to_tensor(scale)
|
||||
self._shift = cast_to_tensor(shift)
|
||||
self._scale = self._add_parameter(scale, 'scale')
|
||||
self._shift = self._add_parameter(shift, 'shift')
|
||||
|
||||
self.abs = P.Abs()
|
||||
self.oneslike = P.OnesLike()
|
||||
|
@ -90,17 +89,18 @@ class ScalarAffine(Bijector):
|
|||
return self._shift
|
||||
|
||||
def extend_repr(self):
|
||||
return f'scale = {self.scale}, shift = {self.shift}'
|
||||
|
||||
def shape_mapping(self, shape):
|
||||
return shape
|
||||
if self.is_scalar_batch:
|
||||
str_info = f'scale = {self.scale}, shift = {self.shift}'
|
||||
else:
|
||||
str_info = f'batch_shape = {self.batch_shape}'
|
||||
return str_info
|
||||
|
||||
def _forward(self, x):
|
||||
r"""
|
||||
.. math::
|
||||
f(x) = a * x + b
|
||||
"""
|
||||
x = self._check_value(x, 'value')
|
||||
x = self._check_value_dtype(x)
|
||||
scale_local = self.cast_param_by_value(x, self.scale)
|
||||
shift_local = self.cast_param_by_value(x, self.shift)
|
||||
forward_v = scale_local * x + shift_local * self.oneslike(x)
|
||||
|
@ -111,7 +111,7 @@ class ScalarAffine(Bijector):
|
|||
.. math::
|
||||
f(y) = \frac{y - b}{a}
|
||||
"""
|
||||
y = self._check_value(y, 'value')
|
||||
y = self._check_value_dtype(y)
|
||||
scale_local = self.cast_param_by_value(y, self.scale)
|
||||
shift_local = self.cast_param_by_value(y, self.shift)
|
||||
inverse_v = (y - shift_local) / scale_local
|
||||
|
@ -124,7 +124,7 @@ class ScalarAffine(Bijector):
|
|||
f'(x) = a
|
||||
\log(f'(x)) = \log(a)
|
||||
"""
|
||||
x = self._check_value(x, 'value')
|
||||
x = self._check_value_dtype(x)
|
||||
scale_local = self.cast_param_by_value(x, self.scale)
|
||||
forward_log_j = self.log(self.abs(scale_local))
|
||||
return forward_log_j
|
||||
|
@ -136,7 +136,7 @@ class ScalarAffine(Bijector):
|
|||
f'(x) = \frac{1.0}{a}
|
||||
\log(f'(x)) = - \log(a)
|
||||
"""
|
||||
y = self._check_value(y, 'value')
|
||||
y = self._check_value_dtype(y)
|
||||
scale_local = self.cast_param_by_value(y, self.scale)
|
||||
inverse_log_j = -1. * self.log(self.abs(scale_local))
|
||||
return inverse_log_j
|
||||
|
|
|
@ -16,8 +16,6 @@
|
|||
import numpy as np
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.nn.layer.activation import LogSigmoid
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from ..distribution._utils.utils import cast_to_tensor
|
||||
from ..distribution._utils.custom_ops import exp_generic, expm1_generic, log_generic
|
||||
from .bijector import Bijector
|
||||
|
||||
|
@ -32,7 +30,7 @@ class Softplus(Bijector):
|
|||
where k is the sharpness factor.
|
||||
|
||||
Args:
|
||||
sharpness (float): The scale factor. Default: 1.0.
|
||||
sharpness (float, list, numpy.ndarray, Tensor): The scale factor. Default: 1.0.
|
||||
name (str): The name of the Bijector. Default: 'Softplus'.
|
||||
|
||||
Examples:
|
||||
|
@ -61,10 +59,9 @@ class Softplus(Bijector):
|
|||
Constructor of Softplus Bijector.
|
||||
"""
|
||||
param = dict(locals())
|
||||
validator.check_value_type('sharpness', sharpness,
|
||||
[int, float], type(self).__name__)
|
||||
super(Softplus, self).__init__(name=name, param=param)
|
||||
self._sharpness = cast_to_tensor(sharpness)
|
||||
param['param_dict'] = {'sharpness': sharpness}
|
||||
super(Softplus, self).__init__(name=name, dtype=None, param=param)
|
||||
self._sharpness = self._add_parameter(sharpness, 'sharpness')
|
||||
|
||||
self.exp = exp_generic
|
||||
self.log = log_generic
|
||||
|
@ -118,13 +115,14 @@ class Softplus(Bijector):
|
|||
return self._sharpness
|
||||
|
||||
def extend_repr(self):
|
||||
return f'sharpness = {self.sharpness}'
|
||||
|
||||
def shape_mapping(self, shape):
|
||||
return shape
|
||||
if self.is_scalar_batch:
|
||||
str_info = f'sharpness = {self.sharpness}'
|
||||
else:
|
||||
str_info = f'batch_shape = {self.batch_shape}'
|
||||
return str_info
|
||||
|
||||
def _forward(self, x):
|
||||
x = self._check_value(x, 'value')
|
||||
x = self._check_value_dtype(x)
|
||||
sharpness_local = self.cast_param_by_value(x, self.sharpness)
|
||||
scaled_value = sharpness_local * x
|
||||
forward_v = self.softplus(scaled_value) / sharpness_local
|
||||
|
@ -136,7 +134,7 @@ class Softplus(Bijector):
|
|||
f(x) = \frac{\log(1 + e^{kx}))}{k}
|
||||
f^{-1}(y) = \frac{\log(e^{ky} - 1)}{k}
|
||||
"""
|
||||
y = self._check_value(y, 'value')
|
||||
y = self._check_value_dtype(y)
|
||||
sharpness_local = self.cast_param_by_value(y, self.sharpness)
|
||||
scaled_value = sharpness_local * y
|
||||
inverse_v = self.inverse_softplus(scaled_value) / sharpness_local
|
||||
|
@ -149,7 +147,7 @@ class Softplus(Bijector):
|
|||
f'(x) = \frac{e^{kx}}{ 1 + e^{kx}}
|
||||
\log(f'(x)) = kx - \log(1 + e^{kx}) = kx - f(kx)
|
||||
"""
|
||||
x = self._check_value(x, 'value')
|
||||
x = self._check_value_dtype(x)
|
||||
sharpness_local = self.cast_param_by_value(x, self.sharpness)
|
||||
scaled_value = sharpness_local * x
|
||||
forward_log_j = self.log_sigmoid(scaled_value)
|
||||
|
@ -162,7 +160,7 @@ class Softplus(Bijector):
|
|||
f'(y) = \frac{e^{ky}}{e^{ky} - 1}
|
||||
\log(f'(y)) = ky - \log(e^{ky} - 1) = ky - f(ky)
|
||||
"""
|
||||
y = self._check_value(y, 'value')
|
||||
y = self._check_value_dtype(y)
|
||||
sharpness_local = self.cast_param_by_value(y, self.sharpness)
|
||||
scaled_value = sharpness_local * y
|
||||
inverse_log_j = scaled_value - self.inverse_softplus(scaled_value)
|
||||
|
|
|
@ -229,6 +229,11 @@ def raise_not_implemented_util(func_name, obj, *args, **kwargs):
|
|||
raise NotImplementedError(
|
||||
f"{func_name} is not implemented for {obj} distribution.")
|
||||
|
||||
@constexpr
|
||||
def raise_type_error(name, cur_type, required_type):
|
||||
raise TypeError(
|
||||
f"For {name} , the type should be or be subclass of {required_type}, but got {cur_type}")
|
||||
|
||||
|
||||
@constexpr
|
||||
def check_distribution_name(name, expected_name):
|
||||
|
@ -304,7 +309,7 @@ def set_param_type(args, hint_type):
|
|||
TypeError: if tensors in args are not the same dtype.
|
||||
"""
|
||||
int_type = mstype.int_type + mstype.uint_type
|
||||
if hint_type in int_type:
|
||||
if hint_type in int_type or hint_type is None:
|
||||
hint_type = mstype.float32
|
||||
common_dtype = None
|
||||
for name, arg in args.items():
|
||||
|
|
|
@ -72,13 +72,12 @@ class Distribution(Cell):
|
|||
if not(k == 'self' or k.startswith('_')):
|
||||
self._parameters[k] = param[k]
|
||||
|
||||
# some attributes
|
||||
if 'distribution' in self.parameters.keys():
|
||||
self.parameter_type = self.parameters['distribution'].parameter_type
|
||||
else:
|
||||
# if not a transformed distribution, set the following attribute
|
||||
if 'distribution' not in self.parameters.keys():
|
||||
self.parameter_type = set_param_type(self.parameters['param_dict'], dtype)
|
||||
self._broadcast_shape = self._calc_broadcast_shape()
|
||||
self._is_scalar_batch = self._check_is_scalar_batch()
|
||||
self._batch_shape = self._calc_batch_shape()
|
||||
self._is_scalar_batch = self._check_is_scalar_batch()
|
||||
self._broadcast_shape = self._batch_shape
|
||||
|
||||
# set the function to call according to the derived class's attributes
|
||||
self._set_prob()
|
||||
|
@ -128,6 +127,10 @@ class Distribution(Cell):
|
|||
def is_scalar_batch(self):
|
||||
return self._is_scalar_batch
|
||||
|
||||
@property
|
||||
def batch_shape(self):
|
||||
return self._batch_shape
|
||||
|
||||
@property
|
||||
def broadcast_shape(self):
|
||||
return self._broadcast_shape
|
||||
|
@ -208,8 +211,6 @@ class Distribution(Cell):
|
|||
"""
|
||||
Check if the parameters used during initialization are scalars.
|
||||
"""
|
||||
if 'distribution' in self.parameters.keys():
|
||||
return self.parameters['distribution'].is_scalar_batch
|
||||
param_dict = self.parameters['param_dict']
|
||||
for value in param_dict.values():
|
||||
if value is None:
|
||||
|
@ -218,12 +219,10 @@ class Distribution(Cell):
|
|||
return False
|
||||
return True
|
||||
|
||||
def _calc_broadcast_shape(self):
|
||||
def _calc_batch_shape(self):
|
||||
"""
|
||||
Calculate the broadcast shape of the parameters used during initialization.
|
||||
"""
|
||||
if 'distribution' in self.parameters.keys():
|
||||
return self.parameters['distribution'].broadcast_shape
|
||||
param_dict = self.parameters['param_dict']
|
||||
broadcast_shape_tensor = None
|
||||
for value in param_dict.values():
|
||||
|
@ -362,14 +361,14 @@ class Distribution(Cell):
|
|||
"""
|
||||
return self._get_dist_args(*args, **kwargs)
|
||||
|
||||
def _get_dist_type(self, *args, **kwargs):
|
||||
return raise_not_implemented_util('get_dist_type', self.name, *args, **kwargs)
|
||||
def _get_dist_type(self):
|
||||
return raise_not_implemented_util('get_dist_type', self.name)
|
||||
|
||||
def get_dist_type(self, *args, **kwargs):
|
||||
def get_dist_type(self):
|
||||
"""
|
||||
Return the type of the distribution.
|
||||
"""
|
||||
return self._get_dist_type(*args, **kwargs)
|
||||
return self._get_dist_type()
|
||||
|
||||
def _raise_not_implemented_error(self, func_name):
|
||||
name = self.name
|
||||
|
@ -751,5 +750,5 @@ class Distribution(Cell):
|
|||
if name == 'get_dist_args':
|
||||
return self._get_dist_args(*args, **kwargs)
|
||||
if name == 'get_dist_type':
|
||||
return self._get_dist_type(*args, **kwargs)
|
||||
return self._get_dist_type()
|
||||
return raise_not_implemented_util(name, self.name, *args, **kwargs)
|
||||
|
|
|
@ -103,17 +103,12 @@ class Gumbel(TransformedDistribution):
|
|||
"""
|
||||
valid_dtype = mstype.float_type
|
||||
Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__)
|
||||
gumbel_cdf = msb.GumbelCDF(loc, scale, dtype)
|
||||
gumbel_cdf = msb.GumbelCDF(loc, scale)
|
||||
super(Gumbel, self).__init__(
|
||||
distribution=msd.Uniform(0.0, 1.0, dtype=dtype),
|
||||
bijector=msb.Invert(gumbel_cdf),
|
||||
seed=seed, name=name)
|
||||
|
||||
self.parameter_type = gumbel_cdf.parameter_type
|
||||
self._broadcast_shape = gumbel_cdf.event_shape
|
||||
if self._broadcast_shape != ():
|
||||
self._is_scalar_batch = False
|
||||
|
||||
# overwrite default_parameters and parameter_names
|
||||
self._reset_parameters()
|
||||
self._loc = self._add_parameter(loc, 'loc')
|
||||
|
@ -202,6 +197,7 @@ class Gumbel(TransformedDistribution):
|
|||
where z = \frac{x - loc}{scale}
|
||||
"""
|
||||
value = self._check_value(value, 'value')
|
||||
value = self.cast(value, self.dtype)
|
||||
z = (value - self.loc) / self.scale
|
||||
return -(z + self.exp(-z)) - self.log(self.scale)
|
||||
|
||||
|
@ -210,6 +206,8 @@ class Gumbel(TransformedDistribution):
|
|||
.. math::
|
||||
cdf_pdf(X) = \exp(-\exp(-\frac{x - loc}{scale})
|
||||
"""
|
||||
value = self._check_value(value, 'value')
|
||||
value = self.cast(value, self.dtype)
|
||||
return self._gumbel_bijector("forward", value)
|
||||
|
||||
def _cross_entropy(self, dist, loc_b, scale_b):
|
||||
|
@ -251,12 +249,14 @@ class Gumbel(TransformedDistribution):
|
|||
self.expm1((loc_b - self.loc) / scale_b + self.lgamma(self.scale / scale_b + 1.))
|
||||
|
||||
def _sample(self, shape=()):
|
||||
shape = self.checktuple(shape, 'shape')
|
||||
origin_shape = shape + self._broadcast_shape
|
||||
if origin_shape == ():
|
||||
sample_shape = (1,)
|
||||
else:
|
||||
sample_shape = origin_shape
|
||||
org_sample = self.distribution("sample", sample_shape)
|
||||
org_sample = self.cast(org_sample, self.dtype)
|
||||
value = self.bijector("forward", org_sample)
|
||||
if origin_shape == ():
|
||||
value = self.squeeze(value)
|
||||
|
|
|
@ -137,6 +137,11 @@ class LogNormal(msd.TransformedDistribution):
|
|||
bijector=msb.Exp(),
|
||||
seed=seed, name=name)
|
||||
|
||||
# overwrite default_parameters and parameter_names
|
||||
self._reset_parameters()
|
||||
self._loc = self._add_parameter(loc, 'loc')
|
||||
self._scale = self._add_parameter(scale, 'scale')
|
||||
|
||||
self.log_2pi = np.log(2 * np.pi)
|
||||
|
||||
#ops needed for the class
|
||||
|
@ -154,12 +159,12 @@ class LogNormal(msd.TransformedDistribution):
|
|||
@property
|
||||
def loc(self):
|
||||
"""Distribution parameter for the pre-transformed mean."""
|
||||
return self.distribution("mean")
|
||||
return self._loc
|
||||
|
||||
@property
|
||||
def scale(self):
|
||||
"""Distribution parameter for the pre-transformed standard deviation."""
|
||||
return self.distribution("sd")
|
||||
return self._scale
|
||||
|
||||
def _get_dist_type(self):
|
||||
return "LogNormal"
|
||||
|
@ -168,18 +173,18 @@ class LogNormal(msd.TransformedDistribution):
|
|||
if loc is not None:
|
||||
self.checktensor(loc, 'loc')
|
||||
else:
|
||||
loc = self.distribution("mean")
|
||||
loc = self.loc
|
||||
if scale is not None:
|
||||
self.checktensor(scale, 'scale')
|
||||
else:
|
||||
scale = self.distribution("sd")
|
||||
scale = self.scale
|
||||
return loc, scale
|
||||
|
||||
def extend_repr(self):
|
||||
if self.is_scalar_batch:
|
||||
s = f'loc = {self._mean_value}, scale = {self._sd_value}'
|
||||
s = f'loc = {self.loc}, scale = {self.scale}'
|
||||
else:
|
||||
s = f'batch_shape = {self._broadcast_shape}'
|
||||
s = f'batch_shape = {self.broadcast_shape}'
|
||||
return s
|
||||
|
||||
def _mean(self, loc=None, scale=None):
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
import numpy as np
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common import dtype as mstype
|
||||
import mindspore.nn as nn
|
||||
from .distribution import Distribution
|
||||
from ._utils.utils import raise_not_impl_error
|
||||
|
@ -30,7 +31,7 @@ class TransformedDistribution(Distribution):
|
|||
|
||||
Args:
|
||||
bijector (Bijector): The transformation to perform.
|
||||
distribution (Distribution): The original distribution.
|
||||
distribution (Distribution): The original distribution. Must has dtype of mindspore.float_type.
|
||||
seed (int): The seed is used in sampling. The global seed is used if it is None. Default:None.
|
||||
If this seed is given when a TransformedDistribution object is initialised, the object's sampling function
|
||||
will use this seed; elsewise, the underlying distribution's seed will be used.
|
||||
|
@ -40,6 +41,12 @@ class TransformedDistribution(Distribution):
|
|||
The arguments used to initialize the original distribution cannot be None.
|
||||
For example, mynormal = nn.Normal(dtype=dtyple.float32) cannot be used to initialized a
|
||||
TransformedDistribution since `mean` and `sd` are not specified.
|
||||
`batch_shape` is the batch_shape of the original distribution.
|
||||
`broadcast_shape` is the broadcast shape between the original distribution and bijector.
|
||||
`is_scalar_batch` is only true if both the original distribution and the bijector are scalar batches.
|
||||
`default_parameters`, `parameter_names` and `parameter_type` are set to be consistent with the original
|
||||
distribution. Derived class can overwrite `default_parameters` and `parameter_names` by calling
|
||||
`reset_parameters` followed by `add_parameter`.
|
||||
|
||||
Examples:
|
||||
>>> # To initialize a transformed distribution, e.g. a lognormal distribution,
|
||||
|
@ -75,28 +82,34 @@ class TransformedDistribution(Distribution):
|
|||
[nn.probability.bijector.Bijector], type(self).__name__)
|
||||
validator.check_value_type('distribution', distribution,
|
||||
[Distribution], type(self).__name__)
|
||||
validator.check_type_name("dtype", distribution.dtype, mstype.float_type, type(self).__name__)
|
||||
super(TransformedDistribution, self).__init__(seed, distribution.dtype, name, param)
|
||||
|
||||
self._bijector = bijector
|
||||
self._distribution = distribution
|
||||
self._is_linear_transformation = bijector.is_constant_jacobian
|
||||
self.default_parameters = distribution.default_parameters
|
||||
self.parameter_names = distribution.parameter_names
|
||||
|
||||
# set attributes
|
||||
self._is_linear_transformation = self.bijector.is_constant_jacobian
|
||||
self._dtype = self.distribution.dtype
|
||||
self._is_scalar_batch = self.distribution.is_scalar_batch and self.bijector.is_scalar_batch
|
||||
self._batch_shape = self.distribution.batch_shape
|
||||
|
||||
self.default_parameters = self.distribution.default_parameters
|
||||
self.parameter_names = self.distribution.parameter_names
|
||||
# by default, set the parameter_type to be the distribution's parameter_type
|
||||
self.parameter_type = self.distribution.parameter_type
|
||||
|
||||
self.exp = exp_generic
|
||||
self.log = log_generic
|
||||
self.isnan = P.IsNan()
|
||||
self.cast_base = P.Cast()
|
||||
self.equal_base = P.Equal()
|
||||
self.select_base = P.Select()
|
||||
self.fill = P.Fill()
|
||||
self.fill_base = P.Fill()
|
||||
|
||||
# broadcast bijector batch_shape and distribution batch_shape
|
||||
self._broadcast_shape = self._broadcast_bijector_dist()
|
||||
|
||||
# check if batch shape of the distribution and event shape is broadcastable
|
||||
if hasattr(self.bijector, 'event_shape'):
|
||||
event_shape_tensor = self.fill(self.dtype, self.bijector.event_shape, 0.0)
|
||||
broadcast_shape_tensor = self.fill(self.dtype, self.broadcast_shape, 0.0)
|
||||
self._batch_event = (event_shape_tensor + broadcast_shape_tensor).shape
|
||||
else:
|
||||
self._batch_event = self.broadcast_shape
|
||||
|
||||
@property
|
||||
def bijector(self):
|
||||
|
@ -108,12 +121,22 @@ class TransformedDistribution(Distribution):
|
|||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self.distribution.dtype
|
||||
return self._dtype
|
||||
|
||||
@property
|
||||
def is_linear_transformation(self):
|
||||
return self._is_linear_transformation
|
||||
|
||||
def _broadcast_bijector_dist(self):
|
||||
"""
|
||||
check if the batch shape of base distribution and the bijector is broadcastable.
|
||||
"""
|
||||
if self.batch_shape is None or self.bijector.batch_shape is None:
|
||||
return None
|
||||
bijector_shape_tensor = self.fill_base(self.dtype, self.bijector.batch_shape, 0.0)
|
||||
dist_shape_tensor = self.fill_base(self.dtype, self.batch_shape, 0.0)
|
||||
return (bijector_shape_tensor + dist_shape_tensor).shape
|
||||
|
||||
def _cdf(self, value, *args, **kwargs):
|
||||
r"""
|
||||
.. math::
|
||||
|
|
|
@ -0,0 +1,191 @@
|
|||
# Copyright 2019 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""test cases for exp"""
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.nn as nn
|
||||
import mindspore.nn.probability.bijector as msb
|
||||
from mindspore import Tensor
|
||||
from mindspore import dtype
|
||||
|
||||
class MyBijector(msb.Bijector):
|
||||
"""
|
||||
Customized bijector class with dtype not specified.
|
||||
"""
|
||||
def __init__(self, param1, param2):
|
||||
param = dict(locals())
|
||||
param['param_dict'] = {'param1': param1, 'param2': param2}
|
||||
super(MyBijector, self).__init__(name='MyBijector', dtype=None, param=param)
|
||||
|
||||
self._param1 = self._add_parameter(param1, 'param1')
|
||||
self._param2 = self._add_parameter(param2, 'param2')
|
||||
|
||||
@property
|
||||
def param1(self):
|
||||
return self._param1
|
||||
|
||||
@property
|
||||
def param2(self):
|
||||
return self._param2
|
||||
|
||||
def _forward(self, value):
|
||||
value = self._check_value_dtype(value)
|
||||
param1_local = self.cast_param_by_value(value, self.param1)
|
||||
param2_local = self.cast_param_by_value(value, self.param2)
|
||||
return value * param1_local + param2_local
|
||||
|
||||
class MySecondBijector(msb.Bijector):
|
||||
"""
|
||||
Customized bijector class with dtype specified.
|
||||
"""
|
||||
def __init__(self, param1, param2):
|
||||
param = dict(locals())
|
||||
param['param_dict'] = {'param1': param1, 'param2': param2}
|
||||
super(MySecondBijector, self).__init__(name='MySecondBijector', dtype=dtype.float32, param=param)
|
||||
|
||||
self._param1 = self._add_parameter(param1, 'param1')
|
||||
self._param2 = self._add_parameter(param2, 'param2')
|
||||
|
||||
@property
|
||||
def param1(self):
|
||||
return self._param1
|
||||
|
||||
@property
|
||||
def param2(self):
|
||||
return self._param2
|
||||
|
||||
def _forward(self, value):
|
||||
value = self._check_value_dtype(value)
|
||||
param1_local = self.cast_param_by_value(value, self.param1)
|
||||
param2_local = self.cast_param_by_value(value, self.param2)
|
||||
return value * param1_local + param2_local
|
||||
|
||||
def test_arguments_same_type():
|
||||
"""
|
||||
Test bijector initializations.
|
||||
"""
|
||||
param1_1 = np.array(1.0).astype(np.float16)
|
||||
param2_1 = np.array(2.0).astype(np.float32)
|
||||
with pytest.raises(TypeError):
|
||||
MyBijector(param1_1, param2_1)
|
||||
param1_2 = Tensor(1.0, dtype=dtype.float16)
|
||||
param2_2 = Tensor(2.0, dtype=dtype.float32)
|
||||
with pytest.raises(TypeError):
|
||||
MyBijector(param1_2, param2_2)
|
||||
with pytest.raises(TypeError):
|
||||
MyBijector(True, param2_2)
|
||||
with pytest.raises(TypeError):
|
||||
MyBijector(None, param2_2)
|
||||
param1_3 = Tensor(1.0, dtype=dtype.float32)
|
||||
param2_3 = Tensor(2.0, dtype=dtype.float32)
|
||||
bijector = MyBijector(param1_3, param2_3)
|
||||
assert isinstance(bijector, msb.Bijector)
|
||||
param1_4 = np.array([1.0, 2.0]).astype(np.float16)
|
||||
param2_4 = np.array([1.0, 2.0]).astype(np.float16)
|
||||
bijector = MyBijector(param1_4, param2_4)
|
||||
assert isinstance(bijector, msb.Bijector)
|
||||
bijector = MyBijector(1.0, 2.0)
|
||||
assert isinstance(bijector, msb.Bijector)
|
||||
|
||||
def test_arguments_with_dtype_specified():
|
||||
"""
|
||||
Customized bijector class with dtype not specified.
|
||||
"""
|
||||
param1_1 = np.array(1.0).astype(np.float16)
|
||||
param2_1 = np.array(2.0).astype(np.float16)
|
||||
with pytest.raises(TypeError):
|
||||
MySecondBijector(param1_1, param2_1)
|
||||
param1_2 = Tensor(1.0, dtype=dtype.float16)
|
||||
param2_2 = Tensor(2.0, dtype=dtype.float32)
|
||||
with pytest.raises(TypeError):
|
||||
MySecondBijector(param1_2, param2_2)
|
||||
with pytest.raises(TypeError):
|
||||
MySecondBijector(True, param2_2)
|
||||
with pytest.raises(TypeError):
|
||||
MySecondBijector(None, param2_2)
|
||||
param1_3 = Tensor(1.0, dtype=dtype.float32)
|
||||
param2_3 = Tensor(2.0, dtype=dtype.float32)
|
||||
bijector = MyBijector(param1_3, param2_3)
|
||||
assert isinstance(bijector, msb.Bijector)
|
||||
param1_4 = np.array(2.0).astype(np.float32)
|
||||
param2_4 = np.array(1.0).astype(np.float32)
|
||||
bijector = MyBijector(param1_4, param2_4)
|
||||
assert isinstance(bijector, msb.Bijector)
|
||||
|
||||
class Net1(nn.Cell):
|
||||
"""
|
||||
Test input value when bijector's dtype is not specified.
|
||||
"""
|
||||
def __init__(self):
|
||||
super(Net1, self).__init__()
|
||||
self.bijector = MyBijector(np.array(1.0).astype(np.float32), np.array(2.0).astype(np.float32))
|
||||
|
||||
def construct(self, value):
|
||||
return self.bijector.forward(value)
|
||||
|
||||
class Net2(nn.Cell):
|
||||
"""
|
||||
Test input value when bijector's dtype is specified.
|
||||
"""
|
||||
def __init__(self):
|
||||
super(Net2, self).__init__()
|
||||
self.bijector = MySecondBijector(np.array(1.0).astype(np.float32), np.array(2.0).astype(np.float32))
|
||||
|
||||
def construct(self, value):
|
||||
return self.bijector.forward(value)
|
||||
|
||||
def test_input_value():
|
||||
"""
|
||||
Test validity of input value.
|
||||
"""
|
||||
net = Net1()
|
||||
value = None
|
||||
with pytest.raises(TypeError):
|
||||
ans = net(value)
|
||||
value = 1.0
|
||||
with pytest.raises(TypeError):
|
||||
ans = net(value)
|
||||
value = Tensor(1.0, dtype=dtype.int32)
|
||||
with pytest.raises(TypeError):
|
||||
ans = net(value)
|
||||
value = Tensor(1.0, dtype=dtype.float32)
|
||||
ans = net(value)
|
||||
assert ans.dtype == dtype.float32
|
||||
value = Tensor(1.0, dtype=dtype.float16)
|
||||
ans = net(value)
|
||||
assert ans.dtype == dtype.float16
|
||||
|
||||
def test_input_value2():
|
||||
"""
|
||||
Test validity of input value.
|
||||
"""
|
||||
net = Net2()
|
||||
value = None
|
||||
with pytest.raises(TypeError):
|
||||
ans = net(value)
|
||||
value = 1.0
|
||||
with pytest.raises(TypeError):
|
||||
ans = net(value)
|
||||
value = Tensor(1.0, dtype=dtype.int32)
|
||||
with pytest.raises(TypeError):
|
||||
ans = net(value)
|
||||
value = Tensor(1.0, dtype=dtype.float16)
|
||||
with pytest.raises(TypeError):
|
||||
ans = net(value)
|
||||
value = Tensor(1.0, dtype=dtype.float32)
|
||||
ans = net(value)
|
||||
assert ans.dtype == dtype.float32
|
||||
|
|
@ -1,3 +1,7 @@
|
|||
# Copyright 2019 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
|
Loading…
Reference in New Issue