Added Bijector and TransformedDistribution base classes and two bijector instances: PowerTransform and Exp

This commit is contained in:
peixu_ren 2020-07-28 21:15:57 -04:00
parent 64b0cb4f95
commit b9c8c0b439
12 changed files with 780 additions and 1 deletions

View File

@ -35,5 +35,4 @@ __all__.extend(metrics.__all__)
__all__.extend(wrap.__all__)
__all__.extend(sparse.__all__)
__all__.sort()

View File

@ -18,4 +18,5 @@ Probability.
The high-level components used to construct the probabilistic network.
"""
from . import bijector
from . import distribution

View File

@ -0,0 +1,27 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Bijector.
The high-level components(Bijectors) used to construct the probabilistic network.
"""
from .bijector import Bijector
from .power_transform import PowerTransform
from .exp import Exp
# Public names exported by the bijector package: the base class and the two concrete bijectors.
__all__ = ['Bijector',
'PowerTransform',
'Exp']

View File

@ -0,0 +1,130 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Bijector"""
from mindspore.nn.cell import Cell
from ..distribution import Distribution
from ..distribution import TransformedDistribution
class Bijector(Cell):
    """
    Bijector class.

    Base class for invertible transformations. The public methods `forward`,
    `inverse`, `forward_log_jacobian` and `inverse_log_jacobian` delegate to the
    underscore-prefixed methods of the same name, which subclasses must provide.

    Args:
        is_constant_jacobian (bool): if the bijector has constant derivative. Default: False.
        is_injective (bool): if the bijector is an one-to-one mapping. Default: True.
        name (str): name of the bijector. Default: None.
        dtype (mstype): type of the distribution the bijector can operate on. Default: None.
        param (dict): parameters used to initialize the bijector. Default: None.
    """

    def __init__(self,
                 is_constant_jacobian=False,
                 is_injective=True,
                 name=None,
                 dtype=None,
                 param=None):
        """
        Constructor of bijector class.
        """
        super(Bijector, self).__init__()
        self._name = name
        self._dtype = dtype
        # `param` defaults to None, so treat a missing dict as "no parameters"
        # instead of failing on `None.keys()`.
        param = {} if param is None else param
        # Subclasses usually pass `dict(locals())`; drop `self` and private or
        # implicit entries (names starting with an underscore).
        self._parameters = {key: value for key, value in param.items()
                            if not (key == 'self' or key.startswith('_'))}
        self._is_constant_jacobian = is_constant_jacobian
        self._is_injective = is_injective

    @property
    def name(self):
        return self._name

    @property
    def dtype(self):
        return self._dtype

    @property
    def parameters(self):
        return self._parameters

    @property
    def is_constant_jacobian(self):
        return self._is_constant_jacobian

    @property
    def is_injective(self):
        return self._is_injective

    def forward(self, *args):
        """
        Forward transformation: transform the input value to another distribution.
        """
        return self._forward(*args)

    def inverse(self, *args):
        """
        Inverse transformation: transform the input value back to the original distribution.
        """
        return self._inverse(*args)

    def forward_log_jacobian(self, *args):
        """
        Logarithm of the derivative of the forward transformation.
        """
        return self._forward_log_jacobian(*args)

    def inverse_log_jacobian(self, *args):
        """
        Logarithm of the derivative of the inverse transformation.
        """
        return self._inverse_log_jacobian(*args)

    def __call__(self, *args):
        """
        Call Bijector directly.

        This __call__ may go into two directions:
        If args[0] is a distribution instance, the call will generate a new distribution derived from
        the input distribution.
        Otherwise, args[0] should be the name of a bijector function, e.g. "forward", then this call will
        go in the construct and invoke the corresponding bijector function.

        Args:
            *args: args[0] shall be either a distribution or the name of a bijector function.
        """
        # Guard against an empty call so that it is reported by Cell.__call__ /
        # construct rather than as an IndexError raised here.
        if args and isinstance(args[0], Distribution):
            return TransformedDistribution(self, args[0])
        return super(Bijector, self).__call__(*args)

    def construct(self, name, *args):
        """
        Override construct in Cell.

        Dispatches to the bijector function selected by `name`.

        Args:
            name (str): one of 'forward', 'inverse', 'forward_log_jacobian' or
                'inverse_log_jacobian'.
            *args: arguments forwarded to the selected function.

        Returns:
            The result of the selected function, or None if `name` is not recognized.
        """
        if name == 'forward':
            return self.forward(*args)
        if name == 'inverse':
            return self.inverse(*args)
        if name == 'forward_log_jacobian':
            return self.forward_log_jacobian(*args)
        if name == 'inverse_log_jacobian':
            return self.inverse_log_jacobian(*args)
        return None

View File

@ -0,0 +1,44 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Power Bijector"""
from .power_transform import PowerTransform
class Exp(PowerTransform):
    r"""
    Exponential Bijector.

    This Bijector performs the operation: Y = exp(x). It is the `PowerTransform`
    bijector constructed with its default power.

    Args:
        name (str): name of the bijector. Default: 'Exp'.

    Examples:
        >>> # To initialize an Exp bijector
        >>> import mindspore.nn.probability.bijector as msb
        >>> n = msb.Exp()
        >>>
        >>> # To use an Exp bijector in a network
        >>> class net(Cell):
        >>>     def __init__(self):
        >>>         super(net, self).__init__()
        >>>         self.e1 = msb.Exp()
        >>>
        >>>     def construct(self, value):
        >>>         # Similar calls can be made to other bijector functions
        >>>         # by replacing 'forward' with the name of the function
        >>>         ans1 = self.e1.forward(value)
        >>>         ans2 = self.e1.inverse(value)
    """

    def __init__(self,
                 name='Exp'):
        # Snapshot the constructor arguments before any other local is bound.
        init_args = dict(locals())
        super(Exp, self).__init__(name=name, param=init_args)

View File

@ -0,0 +1,124 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Power Bijector"""
from mindspore.ops import operations as P
from mindspore._checkparam import Validator as validator
from .bijector import Bijector
class PowerTransform(Bijector):
    r"""
    Power Bijector.

    This Bijector performs the operation: Y = g(X) = (1 + X * c)^(1 / c), X >= -1 / c, where c is power.
    The power transform maps inputs from `[-1/c, inf]` to `[0, inf]`.
    This bijector is equivalent to the `Exp` bijector when `c=0`.

    Args:
        power (int or float): scale factor, must be non-negative. Default: 0.
        name (str): name of the bijector. Default: 'PowerTransform'.
        param (dict): parameters recorded for the bijector. When None, the
            constructor arguments are recorded. Default: None.

    Raises:
        ValueError: if `power` is negative.

    Examples:
        >>> # To initialize a PowerTransform bijector of power 0.5
        >>> import mindspore.nn.probability.bijector as msb
        >>> n = msb.PowerTransform(0.5)
        >>>
        >>> # To use a PowerTransform bijector in a network
        >>> class net(Cell):
        >>>     def __init__(self):
        >>>         super(net, self).__init__()
        >>>         self.p1 = msb.PowerTransform(0.5)
        >>>
        >>>     def construct(self, value):
        >>>         # Similar calls can be made to other bijector functions
        >>>         # by replacing 'forward' with the name of the function
        >>>         ans = self.p1.forward(value)
    """

    def __init__(self,
                 power=0,
                 name='PowerTransform',
                 param=None):
        # Must stay the first statement: locals() is used as the record of the
        # constructor arguments when the caller does not supply `param`.
        param = dict(locals()) if param is None else param
        super(PowerTransform, self).__init__(name=name, param=param)
        validator.check_value_type('power', power, [int, float], self.name)
        # The documented domain X >= -1/c and range [0, inf] only hold for c >= 0.
        if power < 0:
            raise ValueError(f"For '{self.name}', 'power' must be non-negative, but got {power}.")
        self._power = power
        self.pow = P.Pow()
        self.exp = P.Exp()
        self.log = P.Log()
        self.log1p = self._log1p_by_step
        self.expm1 = self._expm1_by_step

    def _log1p_by_step(self, x):
        """
        Compute log(1 + x) in two steps using the Log operator.
        """
        return self.log(x + 1.0)

    def _expm1_by_step(self, x):
        """
        Compute exp(x) - 1 in two steps using the Exp operator.
        """
        return self.exp(x) - 1.0

    @property
    def power(self):
        return self._power

    def extend_repr(self):
        str_info = f'power = {self.power}'
        return str_info

    def shape_mapping(self, shape):
        return shape

    def _forward(self, x):
        r"""
        .. math::
            g(x) = e^x \text{ if } c = 0, \quad g(x) = e^{\frac{\log(xc + 1)}{c}} \text{ otherwise}
        """
        if self.power == 0:
            return self.exp(x)
        return self.exp(self.log1p(x * self.power) / self.power)

    def _inverse(self, y):
        r"""
        .. math::
            g^{-1}(y) = \log(y) \text{ if } c = 0, \quad g^{-1}(y) = \frac{e^{c\log(y)} - 1}{c} \text{ otherwise}
        """
        if self.power == 0:
            return self.log(y)
        return self.expm1(self.log(y) * self.power) / self.power

    def _forward_log_jacobian(self, x):
        r"""
        .. math::
            if c == 0:
                f(x) = e^x
                f'(x) = e^x
                \log(f'(x)) = \log(e^x) = x
            else:
                f(x) = e^\frac{\log(xc + 1)}{c}
                f'(x) = e^\frac{\log(xc + 1)}{c} * \frac{1}{xc + 1}
                \log(f'(x)) = (\frac{1}{c} - 1) * \log(xc + 1)
        """
        if self.power == 0:
            return x
        return (1. / self.power - 1) * self.log1p(x * self.power)

    def _inverse_log_jacobian(self, y):
        r"""
        .. math::
            if c == 0:
                f(y) = \log(y)
                f'(y) = \frac{1}{y}
                \log(f'(y)) = \log(\frac{1}{y}) = -\log(y)
            else:
                f(y) = \frac{e^{c\log(y)} - 1}{c}
                f'(y) = \frac{e^{c\log(y)}}{y}
                \log(f'(y)) = \log(\frac{e^{c\log(y)}}{y}) = (c-1) * \log(y)
        """
        # A single expression covers both cases: for c == 0 it reduces to -log(y).
        return (self.power - 1) * self.log(y)

View File

@ -19,6 +19,7 @@ The high-level components(Distributions) used to construct the probabilistic net
"""
from .distribution import Distribution
from .transformed_distribution import TransformedDistribution
from .normal import Normal
from .bernoulli import Bernoulli
from .exponential import Exponential
@ -26,6 +27,7 @@ from .uniform import Uniform
from .geometric import Geometric
__all__ = ['Distribution',
'TransformedDistribution',
'Normal',
'Bernoulli',
'Exponential',

View File

@ -0,0 +1,94 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Transformed Distribution"""
from mindspore.ops import operations as P
from .distribution import Distribution
class TransformedDistribution(Distribution):
    """
    Transformed Distribution.

    Wraps a base distribution together with a bijector; the resulting distribution
    is the base distribution pushed through the bijector's forward transformation.

    Args:
        bijector (Bijector): transformation to perform.
        distribution (Distribution): The original distribution.
        name (str): name of the transformed distribution. Default: transformed_distribution.

    Note:
        The arguments used to initialize the original distribution cannot be None.
        For example, mynormal = nn.Normal(dtype=dtype.float32) cannot be used to initialize a
        TransformedDistribution since mean and sd are not specified.
    """

    def __init__(self,
                 bijector,
                 distribution,
                 name="transformed_distribution"):
        """
        Constructor of transformed_distribution class.
        """
        # Snapshot the constructor arguments before any other local is bound.
        init_args = dict(locals())
        super(TransformedDistribution, self).__init__(distribution.dtype, name, init_args)
        self._bijector = bijector
        self._distribution = distribution
        self._is_linear_transformation = bijector.is_constant_jacobian
        self.exp = P.Exp()

    @property
    def bijector(self):
        return self._bijector

    @property
    def distribution(self):
        return self._distribution

    @property
    def is_linear_transformation(self):
        return self._is_linear_transformation

    def _cdf(self, value):
        r"""
        .. math::
            Y = g(X)
            P(Y <= a) = P(X <= g^{-1}(a))
        """
        return self.distribution.cdf(self.bijector.inverse(value))

    def _log_prob(self, value):
        r"""
        .. math::
            Y = g(X)
            Py(a) = Px(g^{-1}(a)) * (g^{-1})'(a)
            \log(Py(a)) = \log(Px(g^{-1}(a))) + \log((g^{-1})'(a))
        """
        base_value = self.bijector.inverse(value)
        base_log_prob = self.distribution.log_prob(base_value)
        jacobian_term = self.bijector.inverse_log_jacobian(value)
        return base_log_prob + jacobian_term

    def _prob(self, value):
        log_density = self._log_prob(value)
        return self.exp(log_density)

    def _sample(self, shape):
        return self.bijector.forward(self.distribution.sample(shape))

    def _mean(self):
        """
        Apply the forward transformation to the mean of the original distribution.

        Note:
            This function may be overridden by derived classes.
            NOTE(review): g(E[X]) generally differs from E[g(X)] for a non-affine
            bijector, so this looks like an approximation in that case - confirm.
        """
        base_mean = self.distribution.mean()
        return self.bijector.forward(base_mean)

View File

@ -0,0 +1,105 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test cases for exp"""
import numpy as np
import mindspore.context as context
import mindspore.nn as nn
import mindspore.nn.probability.bijector as msb
from mindspore import Tensor
from mindspore import dtype
# Run the Exp bijector system tests in graph mode with Ascend as the device target.
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
class Net(nn.Cell):
    """
    Test network: applies the forward transformation of the Exp bijector.
    """

    def __init__(self):
        super().__init__()
        self.bijector = msb.Exp()

    def construct(self, x_):
        return self.bijector.forward(x_)
def test_forward():
    """
    Check the forward pass of the Exp bijector against numpy.exp.
    """
    x = np.array([2.0, 3.0, 4.0, 5.0], dtype=np.float32)
    tx = Tensor(x, dtype=dtype.float32)
    forward = Net()
    ans = forward(tx)
    expected = np.exp(x)
    # exp(5) is about 148.4, where one float32 ULP is about 1.5e-5. An absolute
    # tolerance of 1e-5 would therefore demand a bit-exact result, so compare
    # with a relative tolerance instead.
    np.testing.assert_allclose(ans.asnumpy(), expected, rtol=1e-5, atol=1e-6)
class Net1(nn.Cell):
    """
    Test network: applies the inverse transformation of the Exp bijector.
    """

    def __init__(self):
        super().__init__()
        self.bijector = msb.Exp()

    def construct(self, y_):
        return self.bijector.inverse(y_)
def test_inverse():
    """
    Check the inverse pass of the Exp bijector against numpy.log.
    """
    inputs = np.array([2.0, 3.0, 4.0, 5.0], dtype=np.float32)
    net = Net1()
    output = net(Tensor(inputs, dtype=dtype.float32))
    reference = np.log(inputs)
    assert np.all(np.abs(output.asnumpy() - reference) < 1e-6)
class Net2(nn.Cell):
    """
    Test network: log-Jacobian of the forward transformation of the Exp bijector.
    """

    def __init__(self):
        super().__init__()
        self.bijector = msb.Exp()

    def construct(self, x_):
        log_jacobian = self.bijector.forward_log_jacobian(x_)
        return log_jacobian
def test_forward_jacobian():
    """
    Check the forward log-Jacobian of the Exp bijector, which equals the input itself.
    """
    inputs = np.array([2.0, 3.0, 4.0, 5.0], dtype=np.float32)
    net = Net2()
    output = net(Tensor(inputs, dtype=dtype.float32))
    assert np.all(np.abs(output.asnumpy() - inputs) < 1e-6)
class Net3(nn.Cell):
    """
    Test network: log-Jacobian of the inverse transformation of the Exp bijector.
    """

    def __init__(self):
        super().__init__()
        self.bijector = msb.Exp()

    def construct(self, y_):
        log_jacobian = self.bijector.inverse_log_jacobian(y_)
        return log_jacobian
def test_inverse_jacobian():
    """
    Check the inverse log-Jacobian of the Exp bijector against -log(y).
    """
    inputs = np.array([2.0, 3.0, 4.0, 5.0], dtype=np.float32)
    net = Net3()
    output = net(Tensor(inputs, dtype=dtype.float32))
    reference = -np.log(inputs)
    assert np.all(np.abs(output.asnumpy() - reference) < 1e-6)

View File

@ -0,0 +1,109 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test cases for powertransform"""
import numpy as np
import mindspore.context as context
import mindspore.nn as nn
import mindspore.nn.probability.bijector as msb
from mindspore import Tensor
from mindspore import dtype
# Run the PowerTransform bijector system tests in graph mode with Ascend as the device target.
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
class Net(nn.Cell):
    """
    Test network: applies the forward transformation of a PowerTransform bijector.
    """

    def __init__(self, power):
        super().__init__()
        self.bijector = msb.PowerTransform(power=power)

    def construct(self, x_):
        return self.bijector.forward(x_)
def test_forward():
    """
    Check the forward pass of PowerTransform(power=2) against a numpy reference.
    """
    exponent = 2
    inputs = np.array([2.0, 3.0, 4.0, 5.0], dtype=np.float32)
    net = Net(power=exponent)
    output = net(Tensor(inputs, dtype=dtype.float32))
    reference = np.exp(np.log1p(inputs * exponent) / exponent)
    assert np.all(np.abs(output.asnumpy() - reference) < 1e-6)
class Net1(nn.Cell):
    """
    Test network: applies the inverse transformation of a PowerTransform bijector.
    """

    def __init__(self, power):
        super().__init__()
        self.bijector = msb.PowerTransform(power=power)

    def construct(self, y_):
        return self.bijector.inverse(y_)
def test_inverse():
    """
    Check the inverse pass of PowerTransform(power=2) against a numpy reference.
    """
    exponent = 2
    inputs = np.array([2.0, 3.0, 4.0, 5.0], dtype=np.float32)
    net = Net1(power=exponent)
    output = net(Tensor(inputs, dtype=dtype.float32))
    reference = np.expm1(np.log(inputs) * exponent) / exponent
    assert np.all(np.abs(output.asnumpy() - reference) < 1e-6)
class Net2(nn.Cell):
    """
    Test network: log-Jacobian of the forward transformation of a PowerTransform bijector.
    """

    def __init__(self, power):
        super().__init__()
        self.bijector = msb.PowerTransform(power=power)

    def construct(self, x_):
        log_jacobian = self.bijector.forward_log_jacobian(x_)
        return log_jacobian
def test_forward_jacobian():
    """
    Check the forward log-Jacobian of PowerTransform(power=2) against a numpy reference.
    """
    exponent = 2
    inputs = np.array([2.0, 3.0, 4.0, 5.0], dtype=np.float32)
    net = Net2(power=exponent)
    output = net(Tensor(inputs, dtype=dtype.float32))
    reference = (1 / exponent - 1) * np.log1p(inputs * exponent)
    assert np.all(np.abs(output.asnumpy() - reference) < 1e-6)
class Net3(nn.Cell):
    """
    Test network: log-Jacobian of the inverse transformation of a PowerTransform bijector.
    """

    def __init__(self, power):
        super().__init__()
        self.bijector = msb.PowerTransform(power=power)

    def construct(self, y_):
        log_jacobian = self.bijector.inverse_log_jacobian(y_)
        return log_jacobian
def test_inverse_jacobian():
    """
    Check the inverse log-Jacobian of PowerTransform(power=2) against a numpy reference.
    """
    exponent = 2
    inputs = np.array([2.0, 3.0, 4.0, 5.0], dtype=np.float32)
    net = Net3(power=exponent)
    output = net(Tensor(inputs, dtype=dtype.float32))
    reference = (exponent - 1) * np.log(inputs)
    assert np.all(np.abs(output.asnumpy() - reference) < 1e-6)

View File

@ -0,0 +1,71 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test cases for exp"""
import mindspore.nn as nn
import mindspore.nn.probability.bijector as msb
from mindspore import Tensor
from mindspore import dtype
def test_init():
    """
    Test construction of the Exp bijector with the default and with an explicit name.
    """
    b = msb.Exp()
    assert isinstance(b, msb.Bijector)
    # Exp only accepts `name`; the previous positional 1.0 was silently stored
    # as the bijector name rather than configuring anything.
    b = msb.Exp(name='my_exp')
    assert isinstance(b, msb.Bijector)
    assert b.name == 'my_exp'
class Net(nn.Cell):
    """
    Test network: round trip through the forward and inverse pass of the Exp bijector.
    """

    def __init__(self):
        super().__init__()
        self.b1 = msb.Exp()
        self.b2 = msb.Exp()

    def construct(self, x_):
        transformed = self.b1.forward(x_)
        recovered = self.b1.inverse(transformed)
        return x_ - recovered
def test1():
    """
    Test forward and inverse pass of exp bijector.
    """
    inputs = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32)
    result = Net()(inputs)
    assert isinstance(result, Tensor)
class Jacobian(nn.Cell):
    """
    Test network: forward and inverse log-Jacobians of the Exp bijector.
    """

    def __init__(self):
        super().__init__()
        self.b1 = msb.Exp()
        self.b2 = msb.Exp()

    def construct(self, x_):
        forward_term = self.b1.forward_log_jacobian(x_)
        inverse_term = self.b1.inverse_log_jacobian(x_)
        return forward_term + inverse_term
def test2():
    """
    Test jacobians of exp bijector.
    """
    inputs = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32)
    result = Jacobian()(inputs)
    assert isinstance(result, Tensor)

View File

@ -0,0 +1,73 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test cases for powertransform"""
import mindspore.nn as nn
import mindspore.nn.probability.bijector as msb
from mindspore import Tensor
from mindspore import dtype
def test_init():
    """
    Test construction of the PowerTransform bijector with the default and an explicit power.
    """
    default_bijector = msb.PowerTransform()
    assert isinstance(default_bijector, msb.Bijector)
    unit_power_bijector = msb.PowerTransform(1)
    assert isinstance(unit_power_bijector, msb.Bijector)
class Net(nn.Cell):
    """
    Test network: forward/inverse round trips through two PowerTransform bijectors,
    one built with power=0 and one built with the default arguments.
    """

    def __init__(self):
        super().__init__()
        self.b1 = msb.PowerTransform(power=0)
        self.b2 = msb.PowerTransform()

    def construct(self, x_):
        forward1 = self.b1.forward(x_)
        round_trip1 = self.b1.inverse(forward1)
        forward2 = self.b2.forward(x_)
        round_trip2 = self.b2.inverse(forward2)
        return round_trip1 - round_trip2
def test1():
    """
    Test forward and inverse pass of powertransform bijector.
    """
    inputs = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32)
    result = Net()(inputs)
    assert isinstance(result, Tensor)
class Jacobian(nn.Cell):
    """
    Test network: forward and inverse log-Jacobians of two PowerTransform bijectors,
    one built with power=0 and one built with the default arguments.
    """

    def __init__(self):
        super().__init__()
        self.b1 = msb.PowerTransform(power=0)
        self.b2 = msb.PowerTransform()

    def construct(self, x_):
        forward1 = self.b1.forward_log_jacobian(x_)
        forward2 = self.b2.forward_log_jacobian(x_)
        inverse1 = self.b1.inverse_log_jacobian(x_)
        inverse2 = self.b2.inverse_log_jacobian(x_)
        return forward1 - forward2 + inverse1 - inverse2
def test2():
    """
    Test jacobians of powertransform bijector.
    """
    inputs = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32)
    result = Jacobian()(inputs)
    assert isinstance(result, Tensor)