Fix errors in exp calculation logics

This commit is contained in:
peixu_ren 2020-08-21 22:57:25 -04:00
parent ac81886328
commit 03dac9b621
10 changed files with 61 additions and 36 deletions

View File

@ -17,7 +17,7 @@ from mindspore.ops import operations as P
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from ..distribution._utils.utils import CheckTensor
from ..distribution._utils.custom_ops import log_by_step, log1p_by_step, expm1_by_step
from ..distribution._utils.custom_ops import exp_by_step, expm1_by_step, log_by_step, log1p_by_step
from .bijector import Bijector
class PowerTransform(Bijector):
@ -59,10 +59,10 @@ class PowerTransform(Bijector):
validator.check_number("power", power, 0, Rel.GE, self.name)
self._power = power
self.pow = P.Pow()
self.exp = P.Exp()
self.exp = exp_by_step
self.expm1 = expm1_by_step
self.log = log_by_step
self.log1p = log1p_by_step
self.expm1 = expm1_by_step
self.checktensor = CheckTensor()

View File

@ -19,7 +19,7 @@ from mindspore.common import dtype as mstype
from mindspore.nn.layer.activation import LogSigmoid
from mindspore._checkparam import Validator as validator
from ..distribution._utils.utils import cast_to_tensor, CheckTensor
from ..distribution._utils.custom_ops import log_by_step, expm1_by_step
from ..distribution._utils.custom_ops import exp_by_step, expm1_by_step, log_by_step
from .bijector import Bijector
class Softplus(Bijector):
@ -59,10 +59,10 @@ class Softplus(Bijector):
super(Softplus, self).__init__(name=name, param=param)
self._sharpness = cast_to_tensor(sharpness)
self.abs = P.Abs()
self.exp = P.Exp()
self.exp = exp_by_step
self.log = log_by_step
self.expm1 = expm1_by_step
self.abs = P.Abs()
self.fill = P.Fill()
self.greater = P.Greater()
self.less = P.Less()

View File

@ -28,7 +28,8 @@ __all__ = [
'check_scalar_from_param',
'check_prob',
'check_type',
'exp_by_step',
'expm1_by_step',
'log_by_step',
'log1p_by_step',
'expm1_by_step',
]

View File

@ -17,10 +17,36 @@ import numpy as np
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
def exp_by_step(input_x):
"""
Log op on Ascend doesn't supprot int types.
Fix this with casting the type.
"""
exp = P.Exp()
cast = P.Cast()
dtype = P.DType()
checktype = P.IsSubClass()
if checktype(dtype(input_x), mstype.int_):
input_x = cast(input_x, mstype.float32)
elif checktype(dtype(input_x), mstype.float_):
pass
else:
return None
return exp(input_x)
def expm1_by_step(input_x):
"""
Expm1 ops under GPU context.
"""
return exp_by_step(input_x) - 1.0
def log_by_step(input_x):
"""
Log op on Ascend is calculated as log(abs(x)).
Fix this with putting negative values as nan.
And log op on Ascend doesn't supprot int types.
Fix this with casting the type.
"""
log = P.Log()
less = P.Less()
@ -30,8 +56,14 @@ def log_by_step(input_x):
dtype = P.DType()
shape = P.Shape()
select = P.Select()
checktype = P.IsSubClass()
input_x = cast(input_x, mstype.float32)
if checktype(dtype(input_x), mstype.int_):
input_x = cast(input_x, mstype.float32)
elif checktype(dtype(input_x), mstype.float_):
pass
else:
return None
nan = fill(dtype(input_x), shape(input_x), np.nan)
inf = fill(dtype(input_x), shape(input_x), np.inf)
neg_x = less(input_x, 0.0)
@ -45,10 +77,3 @@ def log1p_by_step(x):
Log1p ops on GPU device or when device_target == GPU.
"""
return log_by_step(x + 1.0)
def expm1_by_step(input_x):
"""
Expm1 ops under GPU context.
"""
exp = P.Exp()
return exp(input_x) - 1.0

View File

@ -18,7 +18,7 @@ from mindspore.ops import operations as P
from mindspore.ops import composite as C
from .distribution import Distribution
from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name, raise_none_error
from ._utils.custom_ops import log_by_step
from ._utils.custom_ops import exp_by_step, log_by_step
class Bernoulli(Distribution):
"""
@ -108,15 +108,15 @@ class Bernoulli(Distribution):
self._probs = probs
# ops needed for the class
self.exp = exp_by_step
self.log = log_by_step
self.squeeze = P.Squeeze(0)
self.cast = P.Cast()
self.const = P.ScalarToArray()
self.dtypeop = P.DType()
self.erf = P.Erf()
self.exp = P.Exp()
self.floor = P.Floor()
self.fill = P.Fill()
self.log = log_by_step
self.less = P.Less()
self.shape = P.Shape()
self.select = P.Select()

View File

@ -20,7 +20,7 @@ from mindspore.common import dtype as mstype
from .distribution import Distribution
from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name,\
raise_none_error
from ._utils.custom_ops import log_by_step
from ._utils.custom_ops import exp_by_step, log_by_step
class Exponential(Distribution):
"""
@ -112,14 +112,14 @@ class Exponential(Distribution):
self.minval = np.finfo(np.float).tiny
# ops needed for the class
self.exp = exp_by_step
self.log = log_by_step
self.squeeze = P.Squeeze(0)
self.cast = P.Cast()
self.const = P.ScalarToArray()
self.dtypeop = P.DType()
self.exp = P.Exp()
self.fill = P.Fill()
self.less = P.Less()
self.log = log_by_step
self.select = P.Select()
self.shape = P.Shape()
self.sqrt = P.Sqrt()
@ -277,8 +277,8 @@ class Exponential(Distribution):
minval = self.const(self.minval)
maxval = self.const(1.0)
sample_uniform = self.uniform(sample_shape, minval, maxval, self.seed)
sample = -self.log(sample_uniform) / rate
value = self.cast(sample, self.dtype)
sample = self.log(sample_uniform) / rate
value = self.cast(-sample, self.dtype)
if origin_shape == ():
value = self.squeeze(value)
return value

View File

@ -20,7 +20,7 @@ from mindspore.common import dtype as mstype
from .distribution import Distribution
from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name,\
raise_none_error
from ._utils.custom_ops import log_by_step
from ._utils.custom_ops import exp_by_step, log_by_step
class Geometric(Distribution):
"""
@ -113,16 +113,16 @@ class Geometric(Distribution):
self.minval = np.finfo(np.float).tiny
# ops needed for the class
self.exp = exp_by_step
self.log = log_by_step
self.squeeze = P.Squeeze(0)
self.cast = P.Cast()
self.const = P.ScalarToArray()
self.dtypeop = P.DType()
self.exp = P.Exp()
self.fill = P.Fill()
self.floor = P.Floor()
self.issubclass = P.IsSubClass()
self.less = P.Less()
self.log = log_by_step
self.pow = P.Pow()
self.select = P.Select()
self.shape = P.Shape()

View File

@ -20,7 +20,7 @@ from mindspore.common import dtype as mstype
from .distribution import Distribution
from ._utils.utils import convert_to_batch, check_greater_zero, check_type, check_distribution_name,\
raise_none_error
from ._utils.custom_ops import log_by_step, expm1_by_step
from ._utils.custom_ops import exp_by_step, expm1_by_step, log_by_step
class Normal(Distribution):
"""
@ -114,14 +114,14 @@ class Normal(Distribution):
self._sd_value = sd
#ops needed for the class
self.exp = exp_by_step
self.expm1 = expm1_by_step
self.log = log_by_step
self.squeeze = P.Squeeze(0)
self.cast = P.Cast()
self.const = P.ScalarToArray()
self.erf = P.Erf()
self.exp = P.Exp()
self.expm1 = expm1_by_step
self.fill = P.Fill()
self.log = log_by_step
self.shape = P.Shape()
self.sq = P.Square()
self.sqrt = P.Sqrt()

View File

@ -13,13 +13,12 @@
# limitations under the License.
# ============================================================================
"""Transformed Distribution"""
from mindspore.ops import operations as P
from mindspore._checkparam import Validator as validator
from mindspore.common import dtype as mstype
import mindspore.nn as nn
from .distribution import Distribution
from ._utils.utils import check_type, raise_not_impl_error
from ._utils.custom_ops import log_by_step
from ._utils.custom_ops import exp_by_step, log_by_step
class TransformedDistribution(Distribution):
"""
@ -56,7 +55,7 @@ class TransformedDistribution(Distribution):
self._bijector = bijector
self._distribution = distribution
self._is_linear_transformation = bijector.is_constant_jacobian
self.exp = P.Exp()
self.exp = exp_by_step
self.log = log_by_step
@property

View File

@ -19,7 +19,7 @@ from mindspore.common import dtype as mstype
from .distribution import Distribution
from ._utils.utils import convert_to_batch, check_greater, check_type, check_distribution_name,\
raise_none_error
from ._utils.custom_ops import log_by_step
from ._utils.custom_ops import exp_by_step, log_by_step
class Uniform(Distribution):
"""
@ -113,15 +113,15 @@ class Uniform(Distribution):
self._high = high
# ops needed for the class
self.exp = exp_by_step
self.log = log_by_step
self.squeeze = P.Squeeze(0)
self.cast = P.Cast()
self.const = P.ScalarToArray()
self.dtypeop = P.DType()
self.exp = P.Exp()
self.fill = P.Fill()
self.less = P.Less()
self.lessequal = P.LessEqual()
self.log = log_by_step
self.logicaland = P.LogicalAnd()
self.select = P.Select()
self.shape = P.Shape()