forked from mindspore-Ecosystem/mindspore
Added Bijector TransformDistribution base classes and two instances: power and exp bijectors
This commit is contained in:
parent
64b0cb4f95
commit
b9c8c0b439
|
@ -35,5 +35,4 @@ __all__.extend(metrics.__all__)
|
|||
__all__.extend(wrap.__all__)
|
||||
__all__.extend(sparse.__all__)
|
||||
|
||||
|
||||
__all__.sort()
|
||||
|
|
|
@ -18,4 +18,5 @@ Probability.
|
|||
The high-level components used to construct the probabilistic network.
|
||||
"""
|
||||
|
||||
from . import bijector
|
||||
from . import distribution
|
||||
|
|
|
@ -0,0 +1,27 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
Bijector.
|
||||
|
||||
The high-level components(Bijectors) used to construct the probabilistic network.
|
||||
"""
|
||||
|
||||
from .bijector import Bijector
|
||||
from .power_transform import PowerTransform
|
||||
from .exp import Exp
|
||||
|
||||
__all__ = ['Bijector',
|
||||
'PowerTransform',
|
||||
'Exp']
|
|
@ -0,0 +1,130 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Bijector"""
|
||||
from mindspore.nn.cell import Cell
|
||||
from ..distribution import Distribution
|
||||
from ..distribution import TransformedDistribution
|
||||
|
||||
class Bijector(Cell):
|
||||
"""
|
||||
Bijecotr class.
|
||||
|
||||
Args:
|
||||
is_constant_jacobian (bool): if the bijector has constant derivative. Default: False.
|
||||
is_injective (bool): if the bijector is an one-to-one mapping. Default: True.
|
||||
name (str): name of the bijector. Default: None.
|
||||
dtype (mstype): type of the distribution the bijector can operate on. Default: None.
|
||||
param (dict): parameters used to initialize the bijector. Default: None.
|
||||
"""
|
||||
def __init__(self,
|
||||
is_constant_jacobian=False,
|
||||
is_injective=True,
|
||||
name=None,
|
||||
dtype=None,
|
||||
param=None):
|
||||
|
||||
"""
|
||||
Constructor of bijector class.
|
||||
"""
|
||||
super(Bijector, self).__init__()
|
||||
self._name = name
|
||||
self._dtype = dtype
|
||||
self._parameters = {}
|
||||
# parsing parameters
|
||||
for k in param.keys():
|
||||
if not(k == 'self' or k.startswith('_')):
|
||||
self._parameters[k] = param[k]
|
||||
self._is_constant_jacobian = is_constant_jacobian
|
||||
self._is_injective = is_injective
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self._dtype
|
||||
|
||||
@property
|
||||
def parameters(self):
|
||||
return self._parameters
|
||||
|
||||
@property
|
||||
def is_constant_jacobian(self):
|
||||
return self._is_constant_jacobian
|
||||
|
||||
@property
|
||||
def is_injective(self):
|
||||
return self._is_injective
|
||||
|
||||
def forward(self, *args):
|
||||
"""
|
||||
Forward transformation: transform the input value to another distribution.
|
||||
"""
|
||||
return self._forward(*args)
|
||||
|
||||
def inverse(self, *args):
|
||||
"""
|
||||
Inverse transformation: transform the input value back to the original distribution.
|
||||
"""
|
||||
return self._inverse(*args)
|
||||
|
||||
def forward_log_jacobian(self, *args):
|
||||
"""
|
||||
Logarithm of the derivative of forward transformation.
|
||||
"""
|
||||
return self._forward_log_jacobian(*args)
|
||||
|
||||
def inverse_log_jacobian(self, *args):
|
||||
"""
|
||||
Logarithm of the derivative of forward transformation.
|
||||
"""
|
||||
return self._inverse_log_jacobian(*args)
|
||||
|
||||
def __call__(self, *args):
|
||||
"""
|
||||
Call Bijector directly.
|
||||
This __call__ may go into two directions:
|
||||
If args[0] is a distribution instance, the call will generate a new distribution derived from
|
||||
the input distribution.
|
||||
Otherwise, input[0] should be the name of a bijector function, e.g. "forward", then this call will
|
||||
go in the construct and invoke the correstpoding bijector function.
|
||||
|
||||
Args:
|
||||
*args: args[0] shall be either a distribution or the name of a bijector function.
|
||||
"""
|
||||
if isinstance(args[0], Distribution):
|
||||
return TransformedDistribution(self, args[0])
|
||||
return super(Bijector, self).__call__(*args)
|
||||
|
||||
def construct(self, name, *args):
|
||||
"""
|
||||
Override construct in Cell.
|
||||
|
||||
Args:
|
||||
*inputs: inputs[0] is always the name of a function.
|
||||
|
||||
Notes:
|
||||
Always raise RuntimeError as Distribution should not be called directly.
|
||||
"""
|
||||
if name == 'forward':
|
||||
return self.forward(*args)
|
||||
if name == 'inverse':
|
||||
return self.inverse(*args)
|
||||
if name == 'forward_log_jacobian':
|
||||
return self.forward_log_jacobian(*args)
|
||||
if name == 'inverse_log_jacobian':
|
||||
return self.inverse_log_jacobian(*args)
|
||||
return None
|
|
@ -0,0 +1,44 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Power Bijector"""
|
||||
from .power_transform import PowerTransform
|
||||
|
||||
class Exp(PowerTransform):
|
||||
r"""
|
||||
Exponential Bijector.
|
||||
This Bijector performs the operation: Y = exp(x).
|
||||
|
||||
Examples:
|
||||
>>> # To initialize a Exp bijector
|
||||
>>> import mindspore.nn.probability.bijector as msb
|
||||
>>> n = msb.Exp()
|
||||
>>>
|
||||
>>> # To use Exp distribution in a network
|
||||
>>> class net(Cell):
|
||||
>>> def __init__(self):
|
||||
>>> super(net, self).__init__():
|
||||
>>> self.e1 = msb.Exp()
|
||||
>>>
|
||||
>>> def construct(self, value):
|
||||
>>>
|
||||
>>> # Similar calls can be made to other probability functions
|
||||
>>> # by replacing 'forward' with the name of the function
|
||||
>>> ans1 = self.e1.forward(value)
|
||||
>>> ans2 = self.e1.backward(value)
|
||||
"""
|
||||
def __init__(self,
|
||||
name='Exp'):
|
||||
param = dict(locals())
|
||||
super(Exp, self).__init__(name=name, param=param)
|
|
@ -0,0 +1,124 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Power Bijector"""
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from .bijector import Bijector
|
||||
|
||||
class PowerTransform(Bijector):
|
||||
r"""
|
||||
Power Bijector.
|
||||
This Bijector performs the operation: Y = g(X) = (1 + X * c)^(1 / c), X >= -1 / c, where c is power.
|
||||
|
||||
The power transform maps inputs from `[-1/c, inf]` to `[0, inf]`.
|
||||
|
||||
This bijector is equivalent to the `Exp` bijector when `c=0`
|
||||
|
||||
Args:
|
||||
power (int or float): scale factor. Default: 0.
|
||||
|
||||
Examples:
|
||||
>>> # To initialize a PowerTransform bijector of power 0.5
|
||||
>>> import mindspore.nn.probability.bijector as msb
|
||||
>>> n = msb.PowerTransform(0.5)
|
||||
>>>
|
||||
>>> # To use PowerTransform distribution in a network
|
||||
>>> class net(Cell):
|
||||
>>> def __init__(self):
|
||||
>>> super(net, self).__init__():
|
||||
>>> self.p1 = msb.PowerTransform(0.5)
|
||||
>>>
|
||||
>>> def construct(self, value):
|
||||
>>>
|
||||
>>> # Similar calls can be made to other probability functions
|
||||
>>> # by replacing 'forward' with the name of the function
|
||||
>>> ans = self.p1.forward(, value)
|
||||
"""
|
||||
def __init__(self,
|
||||
power=0,
|
||||
name='PowerTransform',
|
||||
param=None):
|
||||
param = dict(locals()) if param is None else param
|
||||
super(PowerTransform, self).__init__(name=name, param=param)
|
||||
validator.check_value_type('power', power, [int, float], self.name)
|
||||
self._power = power
|
||||
self.pow = P.Pow()
|
||||
self.exp = P.Exp()
|
||||
self.log = P.Log()
|
||||
self.log1p = self._log1p_by_step
|
||||
self.expm1 = self._expm1_by_step
|
||||
|
||||
def _log1p_by_step(self, x):
|
||||
"""
|
||||
Log1p ops on GPU device or when device_target == GPU.
|
||||
"""
|
||||
return self.log(x + 1.0)
|
||||
|
||||
def _expm1_by_step(self, x):
|
||||
"""
|
||||
Expm1 ops on GPU device or when device_target == GPU.
|
||||
"""
|
||||
return self.exp(x) - 1.0
|
||||
|
||||
@property
|
||||
def power(self):
|
||||
return self._power
|
||||
|
||||
def extend_repr(self):
|
||||
str_info = f'power = {self.power}'
|
||||
return str_info
|
||||
|
||||
def shape_mapping(self, shape):
|
||||
return shape
|
||||
|
||||
def _forward(self, x):
|
||||
if self.power == 0:
|
||||
return self.exp(x)
|
||||
return self.exp(self.log1p(x * self.power) / self.power)
|
||||
|
||||
def _inverse(self, y):
|
||||
if self.power == 0:
|
||||
return self.log(y)
|
||||
return self.expm1(self.log(y) * self.power) / self.power
|
||||
|
||||
def _forward_log_jacobian(self, x):
|
||||
r"""
|
||||
.. math:
|
||||
if c == 0:
|
||||
f(x) = e^x
|
||||
f'(x) = e^x
|
||||
\log(f'(x)) = \log(e^x) = x
|
||||
else:
|
||||
f(x) = e^\frac{\log(xc + 1)}{c}
|
||||
f'(x) = e^\frac{\log(xc + 1)}{c} * \frac{1}{xc + 1}
|
||||
\log(f'(x)) = (\frac{1}{c} - 1) * \log(xc + 1)
|
||||
"""
|
||||
if self.power == 0:
|
||||
return x
|
||||
return (1. / self.power - 1) * self.log1p(x * self.power)
|
||||
|
||||
def _inverse_log_jacobian(self, y):
|
||||
r"""
|
||||
.. math:
|
||||
if c == 0:
|
||||
f(x) = \log(x)
|
||||
f'(x) = \frac{1}{x}
|
||||
\log(f'(x)) = \log(\frac{1}{x}) = -\log(x)
|
||||
else:
|
||||
f(x) = \frac{e^\log(y)*c + 1}{c}
|
||||
f'(x) = \frac{e^c\log(y)}{y}
|
||||
\log(f'(x)) = \log(\frac{e^c\log(y)}{y}) = (c-1) * \log(y)
|
||||
"""
|
||||
return (self.power - 1) * self.log(y)
|
|
@ -19,6 +19,7 @@ The high-level components(Distributions) used to construct the probabilistic net
|
|||
"""
|
||||
|
||||
from .distribution import Distribution
|
||||
from .transformed_distribution import TransformedDistribution
|
||||
from .normal import Normal
|
||||
from .bernoulli import Bernoulli
|
||||
from .exponential import Exponential
|
||||
|
@ -26,6 +27,7 @@ from .uniform import Uniform
|
|||
from .geometric import Geometric
|
||||
|
||||
__all__ = ['Distribution',
|
||||
'TransformedDistribution',
|
||||
'Normal',
|
||||
'Bernoulli',
|
||||
'Exponential',
|
||||
|
|
|
@ -0,0 +1,94 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Transformed Distribution"""
|
||||
from mindspore.ops import operations as P
|
||||
from .distribution import Distribution
|
||||
|
||||
class TransformedDistribution(Distribution):
|
||||
"""
|
||||
Transformed Distribution.
|
||||
This class contains a bijector and a distribution and transforms the original distribution
|
||||
to a new distribution through the operation defined by the bijector.
|
||||
|
||||
Args:
|
||||
bijector (Bijector): transformation to perform.
|
||||
distribution (Distribution): The original distribution.
|
||||
name (str): name of the transformed distribution. Default: transformed_distribution.
|
||||
|
||||
Note:
|
||||
The arguments used to initialize the original distribution cannot be None.
|
||||
For example, mynormal = nn.Normal(dtype=dtyple.float32) cannot be used to initialized a
|
||||
TransformedDistribution since mean and sd are not specified.
|
||||
"""
|
||||
def __init__(self,
|
||||
bijector,
|
||||
distribution,
|
||||
name="transformed_distribution"):
|
||||
"""
|
||||
Constructor of transformed_distribution class.
|
||||
"""
|
||||
param = dict(locals())
|
||||
super(TransformedDistribution, self).__init__(distribution.dtype, name, param)
|
||||
self._bijector = bijector
|
||||
self._distribution = distribution
|
||||
self._is_linear_transformation = bijector.is_constant_jacobian
|
||||
self.exp = P.Exp()
|
||||
|
||||
@property
|
||||
def bijector(self):
|
||||
return self._bijector
|
||||
|
||||
@property
|
||||
def distribution(self):
|
||||
return self._distribution
|
||||
|
||||
@property
|
||||
def is_linear_transformation(self):
|
||||
return self._is_linear_transformation
|
||||
|
||||
def _cdf(self, value):
|
||||
r"""
|
||||
.. math::
|
||||
Y = g(X)
|
||||
P(Y <= a) = P(X <= g^{-1}(a))
|
||||
"""
|
||||
inverse_value = self.bijector.inverse(value)
|
||||
return self.distribution.cdf(inverse_value)
|
||||
|
||||
def _log_prob(self, value):
|
||||
r"""
|
||||
.. math::
|
||||
Y = g(X)
|
||||
Py(a) = Px(g^{-1}(a)) * (g^{-1})'(a)
|
||||
\log(Py(a)) = \log(Px(g^{-1}(a))) + \log((g^{-1})'(a))
|
||||
"""
|
||||
inverse_value = self.bijector.inverse(value)
|
||||
unadjust_prob = self.distribution.log_prob(inverse_value)
|
||||
log_jacobian = self.bijector.inverse_log_jacobian(value)
|
||||
return unadjust_prob + log_jacobian
|
||||
|
||||
def _prob(self, value):
|
||||
return self.exp(self._log_prob(value))
|
||||
|
||||
def _sample(self, shape):
|
||||
org_sample = self.distribution.sample(shape)
|
||||
return self.bijector.forward(org_sample)
|
||||
|
||||
def _mean(self):
|
||||
"""
|
||||
Note:
|
||||
This function maybe overridden by derived class.
|
||||
"""
|
||||
return self.bijector.forward(self.distribution.mean())
|
|
@ -0,0 +1,105 @@
|
|||
# Copyright 2019 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""test cases for exp"""
|
||||
import numpy as np
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
import mindspore.nn.probability.bijector as msb
|
||||
from mindspore import Tensor
|
||||
from mindspore import dtype
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
|
||||
class Net(nn.Cell):
|
||||
"""
|
||||
Test class: forward pass of bijector.
|
||||
"""
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.bijector = msb.Exp()
|
||||
|
||||
def construct(self, x_):
|
||||
forward = self.bijector.forward(x_)
|
||||
return forward
|
||||
|
||||
def test_forward():
|
||||
x = np.array([2.0, 3.0, 4.0, 5.0], dtype=np.float32)
|
||||
tx = Tensor(x, dtype=dtype.float32)
|
||||
forward = Net()
|
||||
ans = forward(tx)
|
||||
expected = np.exp(x)
|
||||
tol = 1e-5
|
||||
assert (np.abs(ans.asnumpy() - expected) < tol).all()
|
||||
|
||||
class Net1(nn.Cell):
|
||||
"""
|
||||
Test class: inverse pass of bijector.
|
||||
"""
|
||||
def __init__(self):
|
||||
super(Net1, self).__init__()
|
||||
self.bijector = msb.Exp()
|
||||
|
||||
def construct(self, y_):
|
||||
inverse = self.bijector.inverse(y_)
|
||||
return inverse
|
||||
|
||||
def test_inverse():
|
||||
y = np.array([2.0, 3.0, 4.0, 5.0], dtype=np.float32)
|
||||
ty = Tensor(y, dtype=dtype.float32)
|
||||
inverse = Net1()
|
||||
ans = inverse(ty)
|
||||
expected = np.log(y)
|
||||
tol = 1e-6
|
||||
assert (np.abs(ans.asnumpy() - expected) < tol).all()
|
||||
|
||||
class Net2(nn.Cell):
|
||||
"""
|
||||
Test class: Forward Jacobian.
|
||||
"""
|
||||
def __init__(self):
|
||||
super(Net2, self).__init__()
|
||||
self.bijector = msb.Exp()
|
||||
|
||||
def construct(self, x_):
|
||||
return self.bijector.forward_log_jacobian(x_)
|
||||
|
||||
def test_forward_jacobian():
|
||||
x = np.array([2.0, 3.0, 4.0, 5.0], dtype=np.float32)
|
||||
tx = Tensor(x, dtype=dtype.float32)
|
||||
forward_jacobian = Net2()
|
||||
ans = forward_jacobian(tx)
|
||||
expected = x
|
||||
tol = 1e-6
|
||||
assert (np.abs(ans.asnumpy() - expected) < tol).all()
|
||||
|
||||
class Net3(nn.Cell):
|
||||
"""
|
||||
Test class: Backward Jacobian.
|
||||
"""
|
||||
def __init__(self):
|
||||
super(Net3, self).__init__()
|
||||
self.bijector = msb.Exp()
|
||||
|
||||
def construct(self, y_):
|
||||
return self.bijector.inverse_log_jacobian(y_)
|
||||
|
||||
def test_inverse_jacobian():
|
||||
y = np.array([2.0, 3.0, 4.0, 5.0], dtype=np.float32)
|
||||
ty = Tensor(y, dtype=dtype.float32)
|
||||
inverse_jacobian = Net3()
|
||||
ans = inverse_jacobian(ty)
|
||||
expected = -np.log(y)
|
||||
tol = 1e-6
|
||||
assert (np.abs(ans.asnumpy() - expected) < tol).all()
|
|
@ -0,0 +1,109 @@
|
|||
# Copyright 2019 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""test cases for powertransform"""
|
||||
import numpy as np
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
import mindspore.nn.probability.bijector as msb
|
||||
from mindspore import Tensor
|
||||
from mindspore import dtype
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
|
||||
class Net(nn.Cell):
|
||||
"""
|
||||
Test class: forward pass of bijector.
|
||||
"""
|
||||
def __init__(self, power):
|
||||
super(Net, self).__init__()
|
||||
self.bijector = msb.PowerTransform(power=power)
|
||||
|
||||
def construct(self, x_):
|
||||
forward = self.bijector.forward(x_)
|
||||
return forward
|
||||
|
||||
def test_forward():
|
||||
power = 2
|
||||
x = np.array([2.0, 3.0, 4.0, 5.0], dtype=np.float32)
|
||||
tx = Tensor(x, dtype=dtype.float32)
|
||||
forward = Net(power=power)
|
||||
ans = forward(tx)
|
||||
expected = np.exp(np.log1p(x * power) / power)
|
||||
tol = 1e-6
|
||||
assert (np.abs(ans.asnumpy() - expected) < tol).all()
|
||||
|
||||
class Net1(nn.Cell):
|
||||
"""
|
||||
Test class: inverse pass of bijector.
|
||||
"""
|
||||
def __init__(self, power):
|
||||
super(Net1, self).__init__()
|
||||
self.bijector = msb.PowerTransform(power=power)
|
||||
|
||||
def construct(self, y_):
|
||||
inverse = self.bijector.inverse(y_)
|
||||
return inverse
|
||||
|
||||
def test_inverse():
|
||||
power = 2
|
||||
y = np.array([2.0, 3.0, 4.0, 5.0], dtype=np.float32)
|
||||
ty = Tensor(y, dtype=dtype.float32)
|
||||
inverse = Net1(power=power)
|
||||
ans = inverse(ty)
|
||||
expected = np.expm1(np.log(y) * power) / power
|
||||
tol = 1e-6
|
||||
assert (np.abs(ans.asnumpy() - expected) < tol).all()
|
||||
|
||||
class Net2(nn.Cell):
|
||||
"""
|
||||
Test class: Forward Jacobian.
|
||||
"""
|
||||
def __init__(self, power):
|
||||
super(Net2, self).__init__()
|
||||
self.bijector = msb.PowerTransform(power=power)
|
||||
|
||||
def construct(self, x_):
|
||||
return self.bijector.forward_log_jacobian(x_)
|
||||
|
||||
def test_forward_jacobian():
|
||||
power = 2
|
||||
x = np.array([2.0, 3.0, 4.0, 5.0], dtype=np.float32)
|
||||
tx = Tensor(x, dtype=dtype.float32)
|
||||
forward_jacobian = Net2(power=power)
|
||||
ans = forward_jacobian(tx)
|
||||
expected = (1 / power - 1) * np.log1p(x * power)
|
||||
tol = 1e-6
|
||||
assert (np.abs(ans.asnumpy() - expected) < tol).all()
|
||||
|
||||
class Net3(nn.Cell):
|
||||
"""
|
||||
Test class: Backward Jacobian.
|
||||
"""
|
||||
def __init__(self, power):
|
||||
super(Net3, self).__init__()
|
||||
self.bijector = msb.PowerTransform(power=power)
|
||||
|
||||
def construct(self, y_):
|
||||
return self.bijector.inverse_log_jacobian(y_)
|
||||
|
||||
def test_inverse_jacobian():
|
||||
power = 2
|
||||
y = np.array([2.0, 3.0, 4.0, 5.0], dtype=np.float32)
|
||||
ty = Tensor(y, dtype=dtype.float32)
|
||||
inverse_jacobian = Net3(power=power)
|
||||
ans = inverse_jacobian(ty)
|
||||
expected = (power - 1) * np.log(y)
|
||||
tol = 1e-6
|
||||
assert (np.abs(ans.asnumpy() - expected) < tol).all()
|
|
@ -0,0 +1,71 @@
|
|||
# Copyright 2019 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""test cases for exp"""
|
||||
import mindspore.nn as nn
|
||||
import mindspore.nn.probability.bijector as msb
|
||||
from mindspore import Tensor
|
||||
from mindspore import dtype
|
||||
|
||||
def test_init():
|
||||
b = msb.Exp()
|
||||
assert isinstance(b, msb.Bijector)
|
||||
b = msb.Exp(1.0)
|
||||
assert isinstance(b, msb.Bijector)
|
||||
|
||||
class Net(nn.Cell):
|
||||
"""
|
||||
Test class: forward and inverse pass of bijector.
|
||||
"""
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.b1 = msb.Exp()
|
||||
self.b2 = msb.Exp()
|
||||
|
||||
def construct(self, x_):
|
||||
forward = self.b1.forward(x_)
|
||||
inverse = self.b1.inverse(forward)
|
||||
return x_ - inverse
|
||||
|
||||
def test1():
|
||||
"""
|
||||
Test forward and inverse pass of exp bijector.
|
||||
"""
|
||||
net = Net()
|
||||
x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32)
|
||||
ans = net(x)
|
||||
assert isinstance(ans, Tensor)
|
||||
|
||||
class Jacobian(nn.Cell):
|
||||
"""
|
||||
Test class: forward and inverse pass of bijector.
|
||||
"""
|
||||
def __init__(self):
|
||||
super(Jacobian, self).__init__()
|
||||
self.b1 = msb.Exp()
|
||||
self.b2 = msb.Exp()
|
||||
|
||||
def construct(self, x_):
|
||||
ans1 = self.b1.forward_log_jacobian(x_)
|
||||
ans2 = self.b1.inverse_log_jacobian(x_)
|
||||
return ans1 + ans2
|
||||
|
||||
def test2():
|
||||
"""
|
||||
Test jacobians of exp bijector.
|
||||
"""
|
||||
net = Jacobian()
|
||||
x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32)
|
||||
ans = net(x)
|
||||
assert isinstance(ans, Tensor)
|
|
@ -0,0 +1,73 @@
|
|||
# Copyright 2019 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""test cases for powertransform"""
|
||||
import mindspore.nn as nn
|
||||
import mindspore.nn.probability.bijector as msb
|
||||
from mindspore import Tensor
|
||||
from mindspore import dtype
|
||||
|
||||
def test_init():
|
||||
b = msb.PowerTransform()
|
||||
assert isinstance(b, msb.Bijector)
|
||||
b = msb.PowerTransform(1)
|
||||
assert isinstance(b, msb.Bijector)
|
||||
|
||||
class Net(nn.Cell):
|
||||
"""
|
||||
Test class: forward and inverse pass of bijector.
|
||||
"""
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.b1 = msb.PowerTransform(power=0)
|
||||
self.b2 = msb.PowerTransform()
|
||||
|
||||
def construct(self, x_):
|
||||
ans1 = self.b1.inverse(self.b1.forward(x_))
|
||||
ans2 = self.b2.inverse(self.b2.forward(x_))
|
||||
return ans1 - ans2
|
||||
|
||||
def test1():
|
||||
"""
|
||||
Test forward and inverse pass of powertransform bijector.
|
||||
"""
|
||||
net = Net()
|
||||
x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32)
|
||||
ans = net(x)
|
||||
assert isinstance(ans, Tensor)
|
||||
|
||||
class Jacobian(nn.Cell):
|
||||
"""
|
||||
Test class: forward and inverse pass of bijector.
|
||||
"""
|
||||
def __init__(self):
|
||||
super(Jacobian, self).__init__()
|
||||
self.b1 = msb.PowerTransform(power=0)
|
||||
self.b2 = msb.PowerTransform()
|
||||
|
||||
def construct(self, x_):
|
||||
ans1 = self.b1.forward_log_jacobian(x_)
|
||||
ans2 = self.b2.forward_log_jacobian(x_)
|
||||
ans3 = self.b1.inverse_log_jacobian(x_)
|
||||
ans4 = self.b2.inverse_log_jacobian(x_)
|
||||
return ans1 - ans2 + ans3 - ans4
|
||||
|
||||
def test2():
|
||||
"""
|
||||
Test jacobians of powertransform bijector.
|
||||
"""
|
||||
net = Jacobian()
|
||||
x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32)
|
||||
ans = net(x)
|
||||
assert isinstance(ans, Tensor)
|
Loading…
Reference in New Issue