forked from mindspore-Ecosystem/mindspore
added gumbel_cdf and invert bijectors
This commit is contained in:
parent
fa5c9c1528
commit
f2e8e143be
|
@ -21,6 +21,8 @@ from .power_transform import PowerTransform
|
|||
from .exp import Exp
|
||||
from .scalar_affine import ScalarAffine
|
||||
from .softplus import Softplus
|
||||
from .gumbel_cdf import GumbelCDF
|
||||
from .invert import Invert
|
||||
|
||||
__all__ = [
|
||||
'Bijector',
|
||||
|
@ -28,4 +30,6 @@ __all__ = [
|
|||
'Exp',
|
||||
'ScalarAffine',
|
||||
'Softplus',
|
||||
'GumbelCDF',
|
||||
'Invert',
|
||||
]
|
||||
|
|
|
@ -0,0 +1,107 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""GumbelCDF Bijector"""
|
||||
from mindspore.common import dtype as mstype
|
||||
from ..distribution._utils.utils import cast_to_tensor, check_greater_zero, set_param_type
|
||||
from ..distribution._utils.custom_ops import exp_generic, log_generic
|
||||
from .bijector import Bijector
|
||||
|
||||
|
||||
class GumbelCDF(Bijector):
|
||||
r"""
|
||||
GumbelCDF Bijector.
|
||||
This Bijector performs the operation:
|
||||
|
||||
.. math::
|
||||
Y = \exp(-\exp(\frac{-(X - loc)}{scale}))
|
||||
|
||||
Note:
|
||||
For `reverse` and `reverse_log_jacobian`, input should be in range of (0, 1).
|
||||
|
||||
Args:
|
||||
loc (int, float, list, numpy.ndarray, Tensor): The location. Default: 0..
|
||||
scale (int, float, list, numpy.ndarray, Tensor): The scale. Default: 1.0.
|
||||
name (str): The name of the Bijector. Default: 'Gumbel_CDF'.
|
||||
|
||||
Examples:
|
||||
>>> # To initialize a GumbelCDF bijector of loc 0.0, and scale 1.0.
|
||||
>>> import mindspore.nn.probability.bijector as msb
|
||||
>>> gum = msb.GumbelCDF(0.0, 1.0)
|
||||
>>>
|
||||
>>> # To use GumbelCDF bijector in a network.
|
||||
>>> class net(Cell):
|
||||
>>> def __init__(self):
|
||||
>>> super(net, self).__init__():
|
||||
>>> self.gum = msb.GumbelCDF(0.0, 1.0)
|
||||
>>>
|
||||
>>> def construct(self, value):
|
||||
>>> # Similar calls can be made to other functions
|
||||
>>> # by replacing 'forward' by the name of the function.
|
||||
>>> ans1 = self.gum.forward(value)
|
||||
>>> ans2 = self.gum.inverse(value)
|
||||
>>> ans3 = self.gum.forward_log_jacobian(value)
|
||||
>>> ans4 = self.gum.inverse_log_jacobian(value)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
loc=0.0,
|
||||
scale=1.0,
|
||||
name='GumbelCDF'):
|
||||
"""
|
||||
Constructor of GumbelCDF Bijector.
|
||||
"""
|
||||
param = dict(locals())
|
||||
parameter_type = set_param_type({'loc': loc, "scale": scale}, mstype.float32)
|
||||
super(GumbelCDF, self).__init__(name=name, dtype=parameter_type, param=param)
|
||||
self._loc = cast_to_tensor(loc, parameter_type)
|
||||
self._scale = cast_to_tensor(scale, parameter_type)
|
||||
check_greater_zero(self._scale, "scale")
|
||||
|
||||
self.exp = exp_generic
|
||||
self.log = log_generic
|
||||
|
||||
|
||||
@property
|
||||
def loc(self):
|
||||
return self._loc
|
||||
|
||||
@property
|
||||
def scale(self):
|
||||
return self._scale
|
||||
|
||||
def extend_repr(self):
|
||||
str_info = f'loc = {self.loc}, scale = {self.scale}'
|
||||
return str_info
|
||||
|
||||
def shape_mapping(self, shape):
|
||||
return shape
|
||||
|
||||
def _forward(self, x):
|
||||
x = self._check_value(x, 'value')
|
||||
z = (x - self.loc) / self.scale
|
||||
return self.exp(-self.exp(-z))
|
||||
|
||||
def _inverse(self, y):
|
||||
y = self._check_value(y, 'value')
|
||||
return self.loc - self.scale * self.log(-self.log(y))
|
||||
|
||||
def _forward_log_jacobian(self, x):
|
||||
x = self._check_value(x, 'value')
|
||||
z = (x - self.loc) / self.scale
|
||||
return -z - self.exp(-z) - self.log(self.scale)
|
||||
|
||||
def _inverse_log_jacobian(self, y):
|
||||
y = self._check_value(y, 'value')
|
||||
return self.log(self.scale / (-y * self.log(y)))
|
|
@ -0,0 +1,75 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Invert Bijector"""
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from .bijector import Bijector
|
||||
|
||||
|
||||
class Invert(Bijector):
|
||||
r"""
|
||||
Invert Bijector.
|
||||
|
||||
Args:
|
||||
bijector (Bijector): Base Bijector.
|
||||
name (str): The name of the Bijector. Default: Invert.
|
||||
|
||||
Examples:
|
||||
>>> # To initialize an Invert bijector.
|
||||
>>> import mindspore.nn.probability.bijector as msb
|
||||
>>> n = msb.Invert()
|
||||
>>>
|
||||
>>> # To use an Invert bijector in a network.
|
||||
>>> class net(Cell):
|
||||
>>> def __init__(self):
|
||||
>>> super(net, self).__init__():
|
||||
>>> self.inv = msb.Invert(msb.Exp())
|
||||
>>>
|
||||
>>> def construct(self, value):
|
||||
>>> # Similar calls can be made to other functions
|
||||
>>> # by replacing `forward` by the name of the function.
|
||||
>>> ans1 = self.inv.forward(value)
|
||||
>>> ans2 = self.inv.inverse(value)
|
||||
>>> ans3 = self.inv.forward_log_jacobian(value)
|
||||
>>> ans4 = self.inv.inverse_log_jacobian(value)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
bijector,
|
||||
name='Invert'):
|
||||
param = dict(locals())
|
||||
validator.check_value_type('bijector', bijector, [Bijector], "Invert")
|
||||
name = (name + bijector.name) if name == 'Invert' else name
|
||||
super(Invert, self).__init__(is_constant_jacobian=bijector.is_constant_jacobian,
|
||||
is_injective=bijector.is_injective,
|
||||
dtype=bijector.dtype,
|
||||
name=name,
|
||||
param=param)
|
||||
self._bijector = bijector
|
||||
|
||||
@property
|
||||
def bijector(self):
|
||||
return self._bijector
|
||||
|
||||
def inverse(self, y):
|
||||
return self.bijector("forward", y)
|
||||
|
||||
def forward(self, x):
|
||||
return self.bijector("inverse", x)
|
||||
|
||||
def inverse_log_jacobian(self, y):
|
||||
return self.bijector("forward_log_jacobian", y)
|
||||
|
||||
def forward_log_jacobian(self, x):
|
||||
return self.bijector("inverse_log_jacobian", x)
|
|
@ -0,0 +1,108 @@
|
|||
# Copyright 2019 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""test cases for gumbel_cdf"""
|
||||
import numpy as np
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
import mindspore.nn.probability.bijector as msb
|
||||
from mindspore import Tensor
|
||||
from mindspore import dtype
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
|
||||
class Net(nn.Cell):
|
||||
"""
|
||||
Test class: forward pass of bijector.
|
||||
"""
|
||||
def __init__(self, loc, scale):
|
||||
super(Net, self).__init__()
|
||||
self.bijector = msb.GumbelCDF(loc, scale)
|
||||
|
||||
def construct(self, x_):
|
||||
return self.bijector.forward(x_)
|
||||
|
||||
def test_forward():
|
||||
loc = np.array([0.0])
|
||||
scale = np.array([[1.0], [2.0]])
|
||||
forward = Net(loc, scale)
|
||||
x = np.array([-2., -1., 0., 1., 2.]).astype(np.float32)
|
||||
ans = forward(Tensor(x, dtype=dtype.float32))
|
||||
tol = 1e-6
|
||||
expected = np.exp(-np.exp(-(x - loc)/scale))
|
||||
assert (np.abs(ans.asnumpy() - expected) < tol).all()
|
||||
|
||||
class Net1(nn.Cell):
|
||||
"""
|
||||
Test class: backward pass of bijector.
|
||||
"""
|
||||
def __init__(self, loc, scale):
|
||||
super(Net1, self).__init__()
|
||||
self.bijector = msb.GumbelCDF(loc, scale)
|
||||
|
||||
def construct(self, x_):
|
||||
return self.bijector.inverse(x_)
|
||||
|
||||
def test_backward():
|
||||
loc = np.array([0.0])
|
||||
scale = np.array([[1.0], [2.0]])
|
||||
backward = Net1(loc, scale)
|
||||
x = np.array([0.1, 0.25, 0.5, 0.75, 0.9]).astype(np.float32)
|
||||
ans = backward(Tensor(x, dtype=dtype.float32))
|
||||
tol = 1e-6
|
||||
expected = loc - scale * np.log(-np.log(x))
|
||||
assert (np.abs(ans.asnumpy() - expected) < tol).all()
|
||||
|
||||
class Net2(nn.Cell):
|
||||
"""
|
||||
Test class: Forward Jacobian.
|
||||
"""
|
||||
def __init__(self, loc, scale):
|
||||
super(Net2, self).__init__()
|
||||
self.bijector = msb.GumbelCDF(loc, scale)
|
||||
|
||||
def construct(self, x_):
|
||||
return self.bijector.forward_log_jacobian(x_)
|
||||
|
||||
def test_forward_jacobian():
|
||||
loc = np.array([0.0])
|
||||
scale = np.array([[1.0], [2.0]])
|
||||
forward_jacobian = Net2(loc, scale)
|
||||
x = np.array([-2., -1., 0., 1., 2.]).astype(np.float32)
|
||||
ans = forward_jacobian(Tensor(x))
|
||||
z = (x - loc) / scale
|
||||
expected = -z - np.exp(-z) - np.log(scale)
|
||||
tol = 1e-6
|
||||
assert (np.abs(ans.asnumpy() - expected) < tol).all()
|
||||
|
||||
class Net3(nn.Cell):
|
||||
"""
|
||||
Test class: Backward Jacobian.
|
||||
"""
|
||||
def __init__(self, loc, scale):
|
||||
super(Net3, self).__init__()
|
||||
self.bijector = msb.GumbelCDF(loc, scale)
|
||||
|
||||
def construct(self, x_):
|
||||
return self.bijector.inverse_log_jacobian(x_)
|
||||
|
||||
def test_backward_jacobian():
|
||||
loc = np.array([0.0])
|
||||
scale = np.array([[1.0], [2.0]])
|
||||
backward_jacobian = Net3(loc, scale)
|
||||
x = np.array([0.1, 0.2, 0.5, 0.75, 0.9]).astype(np.float32)
|
||||
ans = backward_jacobian(Tensor(x))
|
||||
expected = np.log(scale / (-x * np.log(x)))
|
||||
tol = 1e-6
|
||||
assert (np.abs(ans.asnumpy() - expected) < tol).all()
|
|
@ -0,0 +1,101 @@
|
|||
# Copyright 2019 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""test cases for invert"""
|
||||
import numpy as np
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
import mindspore.nn.probability.bijector as msb
|
||||
from mindspore import Tensor
|
||||
from mindspore import dtype
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
|
||||
class Net(nn.Cell):
|
||||
"""
|
||||
Test class: forward pass of bijector.
|
||||
"""
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.origin = msb.ScalarAffine(scale=2.0, shift=1.0)
|
||||
self.invert = msb.Invert(self.origin)
|
||||
|
||||
def construct(self, x_):
|
||||
return self.invert.forward(x_), self.origin.inverse(x_)
|
||||
|
||||
def test_forward():
|
||||
forward = Net()
|
||||
x = np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32)
|
||||
ans, ans2 = forward(Tensor(x, dtype=dtype.float32))
|
||||
tol = 1e-6
|
||||
assert (np.abs(ans.asnumpy() - ans2.asnumpy()) < tol).all()
|
||||
|
||||
class Net1(nn.Cell):
|
||||
"""
|
||||
Test class: backward pass of bijector.
|
||||
"""
|
||||
def __init__(self):
|
||||
super(Net1, self).__init__()
|
||||
self.origin = msb.ScalarAffine(scale=2.0, shift=1.0)
|
||||
self.invert = msb.Invert(self.origin)
|
||||
|
||||
def construct(self, x_):
|
||||
return self.invert.inverse(x_), self.origin.forward(x_)
|
||||
|
||||
def test_backward():
|
||||
backward = Net1()
|
||||
x = np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32)
|
||||
ans, ans2 = backward(Tensor(x, dtype=dtype.float32))
|
||||
tol = 1e-6
|
||||
assert (np.abs(ans.asnumpy() - ans2.asnumpy()) < tol).all()
|
||||
|
||||
class Net2(nn.Cell):
|
||||
"""
|
||||
Test class: Forward Jacobian.
|
||||
"""
|
||||
def __init__(self):
|
||||
super(Net2, self).__init__()
|
||||
self.origin = msb.ScalarAffine(scale=2.0, shift=1.0)
|
||||
self.invert = msb.Invert(self.origin)
|
||||
|
||||
def construct(self, x_):
|
||||
return self.invert.forward_log_jacobian(x_),\
|
||||
self.origin.inverse_log_jacobian(x_)
|
||||
|
||||
def test_forward_jacobian():
|
||||
forward_jacobian = Net2()
|
||||
x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32)
|
||||
ans, ans2 = forward_jacobian(x)
|
||||
tol = 1e-6
|
||||
assert (np.abs(ans.asnumpy() - ans2.asnumpy()) < tol).all()
|
||||
|
||||
class Net3(nn.Cell):
|
||||
"""
|
||||
Test class: Backward Jacobian.
|
||||
"""
|
||||
def __init__(self):
|
||||
super(Net3, self).__init__()
|
||||
self.origin = msb.ScalarAffine(scale=2.0, shift=1.0)
|
||||
self.invert = msb.Invert(self.origin)
|
||||
|
||||
def construct(self, x_):
|
||||
return self.invert.inverse_log_jacobian(x_),\
|
||||
self.origin.forward_log_jacobian(x_)
|
||||
|
||||
def test_backward_jacobian():
|
||||
backward_jacobian = Net3()
|
||||
x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32)
|
||||
ans, ans2 = backward_jacobian(x)
|
||||
tol = 1e-6
|
||||
assert (np.abs(ans.asnumpy() - ans2.asnumpy()) < tol).all()
|
|
@ -0,0 +1,148 @@
|
|||
# Copyright 2019 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""test cases for gumbel_cdf"""
|
||||
import pytest
|
||||
import mindspore.nn as nn
|
||||
import mindspore.nn.probability.bijector as msb
|
||||
from mindspore import Tensor
|
||||
from mindspore import dtype
|
||||
|
||||
def test_init():
|
||||
"""
|
||||
Test initializations.
|
||||
"""
|
||||
b = msb.GumbelCDF()
|
||||
assert isinstance(b, msb.Bijector)
|
||||
b = msb.GumbelCDF(scale=1.0)
|
||||
assert isinstance(b, msb.Bijector)
|
||||
b = msb.GumbelCDF(loc=0.0)
|
||||
assert isinstance(b, msb.Bijector)
|
||||
b = msb.GumbelCDF(3.0, 4.0)
|
||||
assert isinstance(b, msb.Bijector)
|
||||
|
||||
def test_type():
|
||||
with pytest.raises(TypeError):
|
||||
msb.GumbelCDF(scale='scale')
|
||||
with pytest.raises(TypeError):
|
||||
msb.GumbelCDF(loc='loc')
|
||||
with pytest.raises(TypeError):
|
||||
msb.GumbelCDF(name=0.1)
|
||||
|
||||
def test_invalid_scale():
|
||||
"""
|
||||
Test invalid scale.
|
||||
"""
|
||||
with pytest.raises(ValueError):
|
||||
msb.GumbelCDF(scale=0.0)
|
||||
with pytest.raises(ValueError):
|
||||
msb.GumbelCDF(scale=-1.0)
|
||||
|
||||
class ForwardBackward(nn.Cell):
|
||||
"""
|
||||
Test class: forward and backward pass.
|
||||
"""
|
||||
def __init__(self):
|
||||
super(ForwardBackward, self).__init__()
|
||||
self.b1 = msb.GumbelCDF(1.0, 2.0)
|
||||
self.b2 = msb.GumbelCDF()
|
||||
|
||||
def construct(self, x_):
|
||||
ans1 = self.b1.inverse(self.b1.forward(x_))
|
||||
ans2 = self.b2.inverse(self.b2.forward(x_))
|
||||
return ans1 + ans2
|
||||
|
||||
def test_forward_and_backward_pass():
|
||||
"""
|
||||
Test forward and backward pass of ScalarAffine bijector.
|
||||
"""
|
||||
net = ForwardBackward()
|
||||
x = Tensor([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype.float32)
|
||||
ans = net(x)
|
||||
assert isinstance(ans, Tensor)
|
||||
|
||||
class ForwardJacobian(nn.Cell):
|
||||
"""
|
||||
Test class: Forward log Jacobian.
|
||||
"""
|
||||
def __init__(self):
|
||||
super(ForwardJacobian, self).__init__()
|
||||
self.b1 = msb.GumbelCDF(1.0, 2.0)
|
||||
self.b2 = msb.GumbelCDF()
|
||||
|
||||
def construct(self, x_):
|
||||
ans1 = self.b1.forward_log_jacobian(x_)
|
||||
ans2 = self.b2.forward_log_jacobian(x_)
|
||||
return ans1 + ans2
|
||||
|
||||
def test_forward_jacobian():
|
||||
"""
|
||||
Test forward log jacobian of ScalarAffine bijector.
|
||||
"""
|
||||
net = ForwardJacobian()
|
||||
x = Tensor([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype.float32)
|
||||
ans = net(x)
|
||||
assert isinstance(ans, Tensor)
|
||||
|
||||
|
||||
class BackwardJacobian(nn.Cell):
|
||||
"""
|
||||
Test class: Backward log Jacobian.
|
||||
"""
|
||||
def __init__(self):
|
||||
super(BackwardJacobian, self).__init__()
|
||||
self.b1 = msb.GumbelCDF(1.0, 2.0)
|
||||
self.b2 = msb.GumbelCDF()
|
||||
|
||||
def construct(self, x_):
|
||||
ans1 = self.b1.inverse_log_jacobian(x_)
|
||||
ans2 = self.b2.inverse_log_jacobian(x_)
|
||||
return ans1 + ans2
|
||||
|
||||
def test_backward_jacobian():
|
||||
"""
|
||||
Test backward log jacobian of ScalarAffine bijector.
|
||||
"""
|
||||
net = BackwardJacobian()
|
||||
x = Tensor([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype.float32)
|
||||
ans = net(x)
|
||||
assert isinstance(ans, Tensor)
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
"""
|
||||
Test class: function calls going through construct.
|
||||
"""
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.b1 = msb.GumbelCDF(1.0, 2.0)
|
||||
self.b2 = msb.GumbelCDF()
|
||||
|
||||
def construct(self, x_):
|
||||
ans1 = self.b1('inverse', self.b1('forward', x_))
|
||||
ans2 = self.b2('inverse', self.b2('forward', x_))
|
||||
ans3 = self.b1('forward_log_jacobian', x_)
|
||||
ans4 = self.b2('forward_log_jacobian', x_)
|
||||
ans5 = self.b1('inverse_log_jacobian', x_)
|
||||
ans6 = self.b2('inverse_log_jacobian', x_)
|
||||
return ans1 - ans2 + ans3 -ans4 + ans5 - ans6
|
||||
|
||||
def test_old_api():
|
||||
"""
|
||||
Test old api which goes through construct.
|
||||
"""
|
||||
net = Net()
|
||||
x = Tensor([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype.float32)
|
||||
ans = net(x)
|
||||
assert isinstance(ans, Tensor)
|
|
@ -0,0 +1,136 @@
|
|||
# Copyright 2019 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""test cases for invert"""
|
||||
import pytest
|
||||
import mindspore.nn as nn
|
||||
import mindspore.nn.probability.bijector as msb
|
||||
from mindspore import Tensor
|
||||
from mindspore import dtype
|
||||
|
||||
def test_init():
|
||||
"""
|
||||
Test initializations.
|
||||
"""
|
||||
b = msb.Invert(msb.ScalarAffine(scale=1.0))
|
||||
assert isinstance(b, msb.Bijector)
|
||||
b = msb.Invert(msb.Exp())
|
||||
assert isinstance(b, msb.Bijector)
|
||||
|
||||
|
||||
def test_type():
|
||||
with pytest.raises(TypeError):
|
||||
msb.Invert(msb.Exp(), name=0.1)
|
||||
with pytest.raises(TypeError):
|
||||
msb.Invert(0.1)
|
||||
|
||||
def test_name():
|
||||
b = msb.Invert(msb.ScalarAffine(scale=1.0))
|
||||
assert b.name == 'InvertScalarAffine'
|
||||
|
||||
class ForwardBackward(nn.Cell):
|
||||
"""
|
||||
Test class: forward and backward pass.
|
||||
"""
|
||||
def __init__(self):
|
||||
super(ForwardBackward, self).__init__()
|
||||
self.inv1 = msb.Invert(msb.Exp())
|
||||
self.inv2 = msb.Invert(msb.ScalarAffine())
|
||||
|
||||
def construct(self, x_):
|
||||
ans1 = self.inv1.inverse(x_) + self.inv1.inverse(x_)
|
||||
ans2 = self.inv2.inverse(x_) + self.inv2.forward(x_)
|
||||
return ans1 + ans2
|
||||
|
||||
def test_forward_and_backward_pass():
|
||||
"""
|
||||
Test forward and backward pass of ScalarAffine bijector.
|
||||
"""
|
||||
net = ForwardBackward()
|
||||
x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32)
|
||||
ans = net(x)
|
||||
assert isinstance(ans, Tensor)
|
||||
|
||||
class ForwardJacobian(nn.Cell):
|
||||
"""
|
||||
Test class: Forward log Jacobian.
|
||||
"""
|
||||
def __init__(self):
|
||||
super(ForwardJacobian, self).__init__()
|
||||
self.inv1 = msb.Invert(msb.Exp())
|
||||
self.inv2 = msb.Invert(msb.ScalarAffine())
|
||||
|
||||
|
||||
def construct(self, x_):
|
||||
ans1 = self.inv1.forward_log_jacobian(x_)
|
||||
ans2 = self.inv2.forward_log_jacobian(x_)
|
||||
return ans1 + ans2
|
||||
|
||||
def test_forward_jacobian():
|
||||
"""
|
||||
Test forward log jacobian of ScalarAffine bijector.
|
||||
"""
|
||||
net = ForwardJacobian()
|
||||
x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32)
|
||||
ans = net(x)
|
||||
assert isinstance(ans, Tensor)
|
||||
|
||||
|
||||
class BackwardJacobian(nn.Cell):
|
||||
"""
|
||||
Test class: Backward log Jacobian.
|
||||
"""
|
||||
def __init__(self):
|
||||
super(BackwardJacobian, self).__init__()
|
||||
self.inv1 = msb.Invert(msb.Exp())
|
||||
self.inv2 = msb.Invert(msb.ScalarAffine())
|
||||
|
||||
def construct(self, x_):
|
||||
ans1 = self.inv1.inverse_log_jacobian(x_)
|
||||
ans2 = self.inv2.inverse_log_jacobian(x_)
|
||||
return ans1 + ans2
|
||||
|
||||
def test_backward_jacobian():
|
||||
"""
|
||||
Test backward log jacobian of ScalarAffine bijector.
|
||||
"""
|
||||
net = BackwardJacobian()
|
||||
x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32)
|
||||
ans = net(x)
|
||||
assert isinstance(ans, Tensor)
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
"""
|
||||
Test class: function calls going through construct.
|
||||
"""
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.b1 = msb.ScalarAffine(2.0, 1.0)
|
||||
self.inv = msb.Invert(self.b1)
|
||||
|
||||
def construct(self, x_):
|
||||
ans1 = self.inv('inverse', self.inv('forward', x_))
|
||||
ans2 = self.inv('forward_log_jacobian', x_)
|
||||
ans3 = self.inv('inverse_log_jacobian', x_)
|
||||
return ans1 + ans2 + ans3
|
||||
|
||||
def test_old_api():
|
||||
"""
|
||||
Test old api which goes through construct.
|
||||
"""
|
||||
net = Net()
|
||||
x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32)
|
||||
ans = net(x)
|
||||
assert isinstance(ans, Tensor)
|
Loading…
Reference in New Issue