added gumbel_cdf and invert bijectors

This commit is contained in:
Xun Deng 2020-10-07 14:40:56 -04:00
parent fa5c9c1528
commit f2e8e143be
7 changed files with 679 additions and 0 deletions

View File

@ -21,6 +21,8 @@ from .power_transform import PowerTransform
from .exp import Exp
from .scalar_affine import ScalarAffine
from .softplus import Softplus
from .gumbel_cdf import GumbelCDF
from .invert import Invert
__all__ = [
'Bijector',
@ -28,4 +30,6 @@ __all__ = [
'Exp',
'ScalarAffine',
'Softplus',
'GumbelCDF',
'Invert',
]

View File

@ -0,0 +1,107 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""GumbelCDF Bijector"""
from mindspore.common import dtype as mstype
from ..distribution._utils.utils import cast_to_tensor, check_greater_zero, set_param_type
from ..distribution._utils.custom_ops import exp_generic, log_generic
from .bijector import Bijector
class GumbelCDF(Bijector):
r"""
GumbelCDF Bijector.
This Bijector performs the operation:
.. math::
Y = \exp(-\exp(\frac{-(X - loc)}{scale}))
Note:
For `reverse` and `reverse_log_jacobian`, input should be in range of (0, 1).
Args:
loc (int, float, list, numpy.ndarray, Tensor): The location. Default: 0..
scale (int, float, list, numpy.ndarray, Tensor): The scale. Default: 1.0.
name (str): The name of the Bijector. Default: 'Gumbel_CDF'.
Examples:
>>> # To initialize a GumbelCDF bijector of loc 0.0, and scale 1.0.
>>> import mindspore.nn.probability.bijector as msb
>>> gum = msb.GumbelCDF(0.0, 1.0)
>>>
>>> # To use GumbelCDF bijector in a network.
>>> class net(Cell):
>>> def __init__(self):
>>> super(net, self).__init__():
>>> self.gum = msb.GumbelCDF(0.0, 1.0)
>>>
>>> def construct(self, value):
>>> # Similar calls can be made to other functions
>>> # by replacing 'forward' by the name of the function.
>>> ans1 = self.gum.forward(value)
>>> ans2 = self.gum.inverse(value)
>>> ans3 = self.gum.forward_log_jacobian(value)
>>> ans4 = self.gum.inverse_log_jacobian(value)
"""
def __init__(self,
loc=0.0,
scale=1.0,
name='GumbelCDF'):
"""
Constructor of GumbelCDF Bijector.
"""
param = dict(locals())
parameter_type = set_param_type({'loc': loc, "scale": scale}, mstype.float32)
super(GumbelCDF, self).__init__(name=name, dtype=parameter_type, param=param)
self._loc = cast_to_tensor(loc, parameter_type)
self._scale = cast_to_tensor(scale, parameter_type)
check_greater_zero(self._scale, "scale")
self.exp = exp_generic
self.log = log_generic
@property
def loc(self):
return self._loc
@property
def scale(self):
return self._scale
def extend_repr(self):
str_info = f'loc = {self.loc}, scale = {self.scale}'
return str_info
def shape_mapping(self, shape):
return shape
def _forward(self, x):
x = self._check_value(x, 'value')
z = (x - self.loc) / self.scale
return self.exp(-self.exp(-z))
def _inverse(self, y):
y = self._check_value(y, 'value')
return self.loc - self.scale * self.log(-self.log(y))
def _forward_log_jacobian(self, x):
x = self._check_value(x, 'value')
z = (x - self.loc) / self.scale
return -z - self.exp(-z) - self.log(self.scale)
def _inverse_log_jacobian(self, y):
y = self._check_value(y, 'value')
return self.log(self.scale / (-y * self.log(y)))

View File

@ -0,0 +1,75 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Invert Bijector"""
from mindspore._checkparam import Validator as validator
from .bijector import Bijector
class Invert(Bijector):
r"""
Invert Bijector.
Args:
bijector (Bijector): Base Bijector.
name (str): The name of the Bijector. Default: Invert.
Examples:
>>> # To initialize an Invert bijector.
>>> import mindspore.nn.probability.bijector as msb
>>> n = msb.Invert()
>>>
>>> # To use an Invert bijector in a network.
>>> class net(Cell):
>>> def __init__(self):
>>> super(net, self).__init__():
>>> self.inv = msb.Invert(msb.Exp())
>>>
>>> def construct(self, value):
>>> # Similar calls can be made to other functions
>>> # by replacing `forward` by the name of the function.
>>> ans1 = self.inv.forward(value)
>>> ans2 = self.inv.inverse(value)
>>> ans3 = self.inv.forward_log_jacobian(value)
>>> ans4 = self.inv.inverse_log_jacobian(value)
"""
def __init__(self,
bijector,
name='Invert'):
param = dict(locals())
validator.check_value_type('bijector', bijector, [Bijector], "Invert")
name = (name + bijector.name) if name == 'Invert' else name
super(Invert, self).__init__(is_constant_jacobian=bijector.is_constant_jacobian,
is_injective=bijector.is_injective,
dtype=bijector.dtype,
name=name,
param=param)
self._bijector = bijector
@property
def bijector(self):
return self._bijector
def inverse(self, y):
return self.bijector("forward", y)
def forward(self, x):
return self.bijector("inverse", x)
def inverse_log_jacobian(self, y):
return self.bijector("forward_log_jacobian", y)
def forward_log_jacobian(self, x):
return self.bijector("inverse_log_jacobian", x)

View File

@ -0,0 +1,108 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test cases for gumbel_cdf"""
import numpy as np
import mindspore.context as context
import mindspore.nn as nn
import mindspore.nn.probability.bijector as msb
from mindspore import Tensor
from mindspore import dtype
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
class Net(nn.Cell):
"""
Test class: forward pass of bijector.
"""
def __init__(self, loc, scale):
super(Net, self).__init__()
self.bijector = msb.GumbelCDF(loc, scale)
def construct(self, x_):
return self.bijector.forward(x_)
def test_forward():
loc = np.array([0.0])
scale = np.array([[1.0], [2.0]])
forward = Net(loc, scale)
x = np.array([-2., -1., 0., 1., 2.]).astype(np.float32)
ans = forward(Tensor(x, dtype=dtype.float32))
tol = 1e-6
expected = np.exp(-np.exp(-(x - loc)/scale))
assert (np.abs(ans.asnumpy() - expected) < tol).all()
class Net1(nn.Cell):
"""
Test class: backward pass of bijector.
"""
def __init__(self, loc, scale):
super(Net1, self).__init__()
self.bijector = msb.GumbelCDF(loc, scale)
def construct(self, x_):
return self.bijector.inverse(x_)
def test_backward():
loc = np.array([0.0])
scale = np.array([[1.0], [2.0]])
backward = Net1(loc, scale)
x = np.array([0.1, 0.25, 0.5, 0.75, 0.9]).astype(np.float32)
ans = backward(Tensor(x, dtype=dtype.float32))
tol = 1e-6
expected = loc - scale * np.log(-np.log(x))
assert (np.abs(ans.asnumpy() - expected) < tol).all()
class Net2(nn.Cell):
"""
Test class: Forward Jacobian.
"""
def __init__(self, loc, scale):
super(Net2, self).__init__()
self.bijector = msb.GumbelCDF(loc, scale)
def construct(self, x_):
return self.bijector.forward_log_jacobian(x_)
def test_forward_jacobian():
loc = np.array([0.0])
scale = np.array([[1.0], [2.0]])
forward_jacobian = Net2(loc, scale)
x = np.array([-2., -1., 0., 1., 2.]).astype(np.float32)
ans = forward_jacobian(Tensor(x))
z = (x - loc) / scale
expected = -z - np.exp(-z) - np.log(scale)
tol = 1e-6
assert (np.abs(ans.asnumpy() - expected) < tol).all()
class Net3(nn.Cell):
"""
Test class: Backward Jacobian.
"""
def __init__(self, loc, scale):
super(Net3, self).__init__()
self.bijector = msb.GumbelCDF(loc, scale)
def construct(self, x_):
return self.bijector.inverse_log_jacobian(x_)
def test_backward_jacobian():
loc = np.array([0.0])
scale = np.array([[1.0], [2.0]])
backward_jacobian = Net3(loc, scale)
x = np.array([0.1, 0.2, 0.5, 0.75, 0.9]).astype(np.float32)
ans = backward_jacobian(Tensor(x))
expected = np.log(scale / (-x * np.log(x)))
tol = 1e-6
assert (np.abs(ans.asnumpy() - expected) < tol).all()

View File

@ -0,0 +1,101 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test cases for invert"""
import numpy as np
import mindspore.context as context
import mindspore.nn as nn
import mindspore.nn.probability.bijector as msb
from mindspore import Tensor
from mindspore import dtype
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
class Net(nn.Cell):
"""
Test class: forward pass of bijector.
"""
def __init__(self):
super(Net, self).__init__()
self.origin = msb.ScalarAffine(scale=2.0, shift=1.0)
self.invert = msb.Invert(self.origin)
def construct(self, x_):
return self.invert.forward(x_), self.origin.inverse(x_)
def test_forward():
forward = Net()
x = np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32)
ans, ans2 = forward(Tensor(x, dtype=dtype.float32))
tol = 1e-6
assert (np.abs(ans.asnumpy() - ans2.asnumpy()) < tol).all()
class Net1(nn.Cell):
"""
Test class: backward pass of bijector.
"""
def __init__(self):
super(Net1, self).__init__()
self.origin = msb.ScalarAffine(scale=2.0, shift=1.0)
self.invert = msb.Invert(self.origin)
def construct(self, x_):
return self.invert.inverse(x_), self.origin.forward(x_)
def test_backward():
backward = Net1()
x = np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32)
ans, ans2 = backward(Tensor(x, dtype=dtype.float32))
tol = 1e-6
assert (np.abs(ans.asnumpy() - ans2.asnumpy()) < tol).all()
class Net2(nn.Cell):
"""
Test class: Forward Jacobian.
"""
def __init__(self):
super(Net2, self).__init__()
self.origin = msb.ScalarAffine(scale=2.0, shift=1.0)
self.invert = msb.Invert(self.origin)
def construct(self, x_):
return self.invert.forward_log_jacobian(x_),\
self.origin.inverse_log_jacobian(x_)
def test_forward_jacobian():
forward_jacobian = Net2()
x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32)
ans, ans2 = forward_jacobian(x)
tol = 1e-6
assert (np.abs(ans.asnumpy() - ans2.asnumpy()) < tol).all()
class Net3(nn.Cell):
"""
Test class: Backward Jacobian.
"""
def __init__(self):
super(Net3, self).__init__()
self.origin = msb.ScalarAffine(scale=2.0, shift=1.0)
self.invert = msb.Invert(self.origin)
def construct(self, x_):
return self.invert.inverse_log_jacobian(x_),\
self.origin.forward_log_jacobian(x_)
def test_backward_jacobian():
backward_jacobian = Net3()
x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32)
ans, ans2 = backward_jacobian(x)
tol = 1e-6
assert (np.abs(ans.asnumpy() - ans2.asnumpy()) < tol).all()

View File

@ -0,0 +1,148 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test cases for gumbel_cdf"""
import pytest
import mindspore.nn as nn
import mindspore.nn.probability.bijector as msb
from mindspore import Tensor
from mindspore import dtype
def test_init():
"""
Test initializations.
"""
b = msb.GumbelCDF()
assert isinstance(b, msb.Bijector)
b = msb.GumbelCDF(scale=1.0)
assert isinstance(b, msb.Bijector)
b = msb.GumbelCDF(loc=0.0)
assert isinstance(b, msb.Bijector)
b = msb.GumbelCDF(3.0, 4.0)
assert isinstance(b, msb.Bijector)
def test_type():
with pytest.raises(TypeError):
msb.GumbelCDF(scale='scale')
with pytest.raises(TypeError):
msb.GumbelCDF(loc='loc')
with pytest.raises(TypeError):
msb.GumbelCDF(name=0.1)
def test_invalid_scale():
"""
Test invalid scale.
"""
with pytest.raises(ValueError):
msb.GumbelCDF(scale=0.0)
with pytest.raises(ValueError):
msb.GumbelCDF(scale=-1.0)
class ForwardBackward(nn.Cell):
"""
Test class: forward and backward pass.
"""
def __init__(self):
super(ForwardBackward, self).__init__()
self.b1 = msb.GumbelCDF(1.0, 2.0)
self.b2 = msb.GumbelCDF()
def construct(self, x_):
ans1 = self.b1.inverse(self.b1.forward(x_))
ans2 = self.b2.inverse(self.b2.forward(x_))
return ans1 + ans2
def test_forward_and_backward_pass():
"""
Test forward and backward pass of ScalarAffine bijector.
"""
net = ForwardBackward()
x = Tensor([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype.float32)
ans = net(x)
assert isinstance(ans, Tensor)
class ForwardJacobian(nn.Cell):
"""
Test class: Forward log Jacobian.
"""
def __init__(self):
super(ForwardJacobian, self).__init__()
self.b1 = msb.GumbelCDF(1.0, 2.0)
self.b2 = msb.GumbelCDF()
def construct(self, x_):
ans1 = self.b1.forward_log_jacobian(x_)
ans2 = self.b2.forward_log_jacobian(x_)
return ans1 + ans2
def test_forward_jacobian():
"""
Test forward log jacobian of ScalarAffine bijector.
"""
net = ForwardJacobian()
x = Tensor([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype.float32)
ans = net(x)
assert isinstance(ans, Tensor)
class BackwardJacobian(nn.Cell):
"""
Test class: Backward log Jacobian.
"""
def __init__(self):
super(BackwardJacobian, self).__init__()
self.b1 = msb.GumbelCDF(1.0, 2.0)
self.b2 = msb.GumbelCDF()
def construct(self, x_):
ans1 = self.b1.inverse_log_jacobian(x_)
ans2 = self.b2.inverse_log_jacobian(x_)
return ans1 + ans2
def test_backward_jacobian():
"""
Test backward log jacobian of ScalarAffine bijector.
"""
net = BackwardJacobian()
x = Tensor([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype.float32)
ans = net(x)
assert isinstance(ans, Tensor)
class Net(nn.Cell):
"""
Test class: function calls going through construct.
"""
def __init__(self):
super(Net, self).__init__()
self.b1 = msb.GumbelCDF(1.0, 2.0)
self.b2 = msb.GumbelCDF()
def construct(self, x_):
ans1 = self.b1('inverse', self.b1('forward', x_))
ans2 = self.b2('inverse', self.b2('forward', x_))
ans3 = self.b1('forward_log_jacobian', x_)
ans4 = self.b2('forward_log_jacobian', x_)
ans5 = self.b1('inverse_log_jacobian', x_)
ans6 = self.b2('inverse_log_jacobian', x_)
return ans1 - ans2 + ans3 -ans4 + ans5 - ans6
def test_old_api():
"""
Test old api which goes through construct.
"""
net = Net()
x = Tensor([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype.float32)
ans = net(x)
assert isinstance(ans, Tensor)

View File

@ -0,0 +1,136 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test cases for invert"""
import pytest
import mindspore.nn as nn
import mindspore.nn.probability.bijector as msb
from mindspore import Tensor
from mindspore import dtype
def test_init():
"""
Test initializations.
"""
b = msb.Invert(msb.ScalarAffine(scale=1.0))
assert isinstance(b, msb.Bijector)
b = msb.Invert(msb.Exp())
assert isinstance(b, msb.Bijector)
def test_type():
with pytest.raises(TypeError):
msb.Invert(msb.Exp(), name=0.1)
with pytest.raises(TypeError):
msb.Invert(0.1)
def test_name():
b = msb.Invert(msb.ScalarAffine(scale=1.0))
assert b.name == 'InvertScalarAffine'
class ForwardBackward(nn.Cell):
"""
Test class: forward and backward pass.
"""
def __init__(self):
super(ForwardBackward, self).__init__()
self.inv1 = msb.Invert(msb.Exp())
self.inv2 = msb.Invert(msb.ScalarAffine())
def construct(self, x_):
ans1 = self.inv1.inverse(x_) + self.inv1.inverse(x_)
ans2 = self.inv2.inverse(x_) + self.inv2.forward(x_)
return ans1 + ans2
def test_forward_and_backward_pass():
"""
Test forward and backward pass of ScalarAffine bijector.
"""
net = ForwardBackward()
x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32)
ans = net(x)
assert isinstance(ans, Tensor)
class ForwardJacobian(nn.Cell):
"""
Test class: Forward log Jacobian.
"""
def __init__(self):
super(ForwardJacobian, self).__init__()
self.inv1 = msb.Invert(msb.Exp())
self.inv2 = msb.Invert(msb.ScalarAffine())
def construct(self, x_):
ans1 = self.inv1.forward_log_jacobian(x_)
ans2 = self.inv2.forward_log_jacobian(x_)
return ans1 + ans2
def test_forward_jacobian():
"""
Test forward log jacobian of ScalarAffine bijector.
"""
net = ForwardJacobian()
x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32)
ans = net(x)
assert isinstance(ans, Tensor)
class BackwardJacobian(nn.Cell):
"""
Test class: Backward log Jacobian.
"""
def __init__(self):
super(BackwardJacobian, self).__init__()
self.inv1 = msb.Invert(msb.Exp())
self.inv2 = msb.Invert(msb.ScalarAffine())
def construct(self, x_):
ans1 = self.inv1.inverse_log_jacobian(x_)
ans2 = self.inv2.inverse_log_jacobian(x_)
return ans1 + ans2
def test_backward_jacobian():
"""
Test backward log jacobian of ScalarAffine bijector.
"""
net = BackwardJacobian()
x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32)
ans = net(x)
assert isinstance(ans, Tensor)
class Net(nn.Cell):
"""
Test class: function calls going through construct.
"""
def __init__(self):
super(Net, self).__init__()
self.b1 = msb.ScalarAffine(2.0, 1.0)
self.inv = msb.Invert(self.b1)
def construct(self, x_):
ans1 = self.inv('inverse', self.inv('forward', x_))
ans2 = self.inv('forward_log_jacobian', x_)
ans3 = self.inv('inverse_log_jacobian', x_)
return ans1 + ans2 + ans3
def test_old_api():
"""
Test old api which goes through construct.
"""
net = Net()
x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32)
ans = net(x)
assert isinstance(ans, Tensor)