forked from mindspore-Ecosystem/mindspore
!4973 Fix errors in exp calculation logics
Merge pull request !4973 from peixu_ren/custom_pp_ops
This commit is contained in:
commit
add52da73e
|
@ -17,7 +17,7 @@ from mindspore.ops import operations as P
|
|||
from mindspore._checkparam import Validator as validator
|
||||
from mindspore._checkparam import Rel
|
||||
from ..distribution._utils.utils import CheckTensor
|
||||
from ..distribution._utils.custom_ops import log_by_step, log1p_by_step, expm1_by_step
|
||||
from ..distribution._utils.custom_ops import exp_by_step, expm1_by_step, log_by_step, log1p_by_step
|
||||
from .bijector import Bijector
|
||||
|
||||
class PowerTransform(Bijector):
|
||||
|
@ -59,10 +59,10 @@ class PowerTransform(Bijector):
|
|||
validator.check_number("power", power, 0, Rel.GE, self.name)
|
||||
self._power = power
|
||||
self.pow = P.Pow()
|
||||
self.exp = P.Exp()
|
||||
self.exp = exp_by_step
|
||||
self.expm1 = expm1_by_step
|
||||
self.log = log_by_step
|
||||
self.log1p = log1p_by_step
|
||||
self.expm1 = expm1_by_step
|
||||
|
||||
self.checktensor = CheckTensor()
|
||||
|
||||
|
|
|
@ -19,7 +19,7 @@ from mindspore.common import dtype as mstype
|
|||
from mindspore.nn.layer.activation import LogSigmoid
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from ..distribution._utils.utils import cast_to_tensor, CheckTensor
|
||||
from ..distribution._utils.custom_ops import log_by_step, expm1_by_step
|
||||
from ..distribution._utils.custom_ops import exp_by_step, expm1_by_step, log_by_step
|
||||
from .bijector import Bijector
|
||||
|
||||
class Softplus(Bijector):
|
||||
|
@ -59,10 +59,10 @@ class Softplus(Bijector):
|
|||
super(Softplus, self).__init__(name=name, param=param)
|
||||
self._sharpness = cast_to_tensor(sharpness)
|
||||
|
||||
self.abs = P.Abs()
|
||||
self.exp = P.Exp()
|
||||
self.exp = exp_by_step
|
||||
self.log = log_by_step
|
||||
self.expm1 = expm1_by_step
|
||||
self.abs = P.Abs()
|
||||
self.fill = P.Fill()
|
||||
self.greater = P.Greater()
|
||||
self.less = P.Less()
|
||||
|
|
|
@ -28,7 +28,8 @@ __all__ = [
|
|||
'check_scalar_from_param',
|
||||
'check_prob',
|
||||
'check_type',
|
||||
'exp_by_step',
|
||||
'expm1_by_step',
|
||||
'log_by_step',
|
||||
'log1p_by_step',
|
||||
'expm1_by_step',
|
||||
]
|
||||
|
|
|
@ -17,10 +17,36 @@ import numpy as np
|
|||
from mindspore.ops import operations as P
|
||||
from mindspore.common import dtype as mstype
|
||||
|
||||
def exp_by_step(input_x):
|
||||
"""
|
||||
Log op on Ascend doesn't supprot int types.
|
||||
Fix this with casting the type.
|
||||
"""
|
||||
exp = P.Exp()
|
||||
cast = P.Cast()
|
||||
dtype = P.DType()
|
||||
checktype = P.IsSubClass()
|
||||
|
||||
if checktype(dtype(input_x), mstype.int_):
|
||||
input_x = cast(input_x, mstype.float32)
|
||||
elif checktype(dtype(input_x), mstype.float_):
|
||||
pass
|
||||
else:
|
||||
return None
|
||||
return exp(input_x)
|
||||
|
||||
def expm1_by_step(input_x):
|
||||
"""
|
||||
Expm1 ops under GPU context.
|
||||
"""
|
||||
return exp_by_step(input_x) - 1.0
|
||||
|
||||
def log_by_step(input_x):
|
||||
"""
|
||||
Log op on Ascend is calculated as log(abs(x)).
|
||||
Fix this with putting negative values as nan.
|
||||
And log op on Ascend doesn't supprot int types.
|
||||
Fix this with casting the type.
|
||||
"""
|
||||
log = P.Log()
|
||||
less = P.Less()
|
||||
|
@ -30,8 +56,14 @@ def log_by_step(input_x):
|
|||
dtype = P.DType()
|
||||
shape = P.Shape()
|
||||
select = P.Select()
|
||||
checktype = P.IsSubClass()
|
||||
|
||||
input_x = cast(input_x, mstype.float32)
|
||||
if checktype(dtype(input_x), mstype.int_):
|
||||
input_x = cast(input_x, mstype.float32)
|
||||
elif checktype(dtype(input_x), mstype.float_):
|
||||
pass
|
||||
else:
|
||||
return None
|
||||
nan = fill(dtype(input_x), shape(input_x), np.nan)
|
||||
inf = fill(dtype(input_x), shape(input_x), np.inf)
|
||||
neg_x = less(input_x, 0.0)
|
||||
|
@ -45,10 +77,3 @@ def log1p_by_step(x):
|
|||
Log1p ops on GPU device or when device_target == GPU.
|
||||
"""
|
||||
return log_by_step(x + 1.0)
|
||||
|
||||
def expm1_by_step(input_x):
|
||||
"""
|
||||
Expm1 ops under GPU context.
|
||||
"""
|
||||
exp = P.Exp()
|
||||
return exp(input_x) - 1.0
|
||||
|
|
|
@ -18,7 +18,7 @@ from mindspore.ops import operations as P
|
|||
from mindspore.ops import composite as C
|
||||
from .distribution import Distribution
|
||||
from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name, raise_none_error
|
||||
from ._utils.custom_ops import log_by_step
|
||||
from ._utils.custom_ops import exp_by_step, log_by_step
|
||||
|
||||
class Bernoulli(Distribution):
|
||||
"""
|
||||
|
@ -108,15 +108,15 @@ class Bernoulli(Distribution):
|
|||
self._probs = probs
|
||||
|
||||
# ops needed for the class
|
||||
self.exp = exp_by_step
|
||||
self.log = log_by_step
|
||||
self.squeeze = P.Squeeze(0)
|
||||
self.cast = P.Cast()
|
||||
self.const = P.ScalarToArray()
|
||||
self.dtypeop = P.DType()
|
||||
self.erf = P.Erf()
|
||||
self.exp = P.Exp()
|
||||
self.floor = P.Floor()
|
||||
self.fill = P.Fill()
|
||||
self.log = log_by_step
|
||||
self.less = P.Less()
|
||||
self.shape = P.Shape()
|
||||
self.select = P.Select()
|
||||
|
|
|
@ -20,7 +20,7 @@ from mindspore.common import dtype as mstype
|
|||
from .distribution import Distribution
|
||||
from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name,\
|
||||
raise_none_error
|
||||
from ._utils.custom_ops import log_by_step
|
||||
from ._utils.custom_ops import exp_by_step, log_by_step
|
||||
|
||||
class Exponential(Distribution):
|
||||
"""
|
||||
|
@ -112,14 +112,14 @@ class Exponential(Distribution):
|
|||
self.minval = np.finfo(np.float).tiny
|
||||
|
||||
# ops needed for the class
|
||||
self.exp = exp_by_step
|
||||
self.log = log_by_step
|
||||
self.squeeze = P.Squeeze(0)
|
||||
self.cast = P.Cast()
|
||||
self.const = P.ScalarToArray()
|
||||
self.dtypeop = P.DType()
|
||||
self.exp = P.Exp()
|
||||
self.fill = P.Fill()
|
||||
self.less = P.Less()
|
||||
self.log = log_by_step
|
||||
self.select = P.Select()
|
||||
self.shape = P.Shape()
|
||||
self.sqrt = P.Sqrt()
|
||||
|
@ -277,8 +277,8 @@ class Exponential(Distribution):
|
|||
minval = self.const(self.minval)
|
||||
maxval = self.const(1.0)
|
||||
sample_uniform = self.uniform(sample_shape, minval, maxval, self.seed)
|
||||
sample = -self.log(sample_uniform) / rate
|
||||
value = self.cast(sample, self.dtype)
|
||||
sample = self.log(sample_uniform) / rate
|
||||
value = self.cast(-sample, self.dtype)
|
||||
if origin_shape == ():
|
||||
value = self.squeeze(value)
|
||||
return value
|
||||
|
|
|
@ -20,7 +20,7 @@ from mindspore.common import dtype as mstype
|
|||
from .distribution import Distribution
|
||||
from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name,\
|
||||
raise_none_error
|
||||
from ._utils.custom_ops import log_by_step
|
||||
from ._utils.custom_ops import exp_by_step, log_by_step
|
||||
|
||||
class Geometric(Distribution):
|
||||
"""
|
||||
|
@ -113,16 +113,16 @@ class Geometric(Distribution):
|
|||
self.minval = np.finfo(np.float).tiny
|
||||
|
||||
# ops needed for the class
|
||||
self.exp = exp_by_step
|
||||
self.log = log_by_step
|
||||
self.squeeze = P.Squeeze(0)
|
||||
self.cast = P.Cast()
|
||||
self.const = P.ScalarToArray()
|
||||
self.dtypeop = P.DType()
|
||||
self.exp = P.Exp()
|
||||
self.fill = P.Fill()
|
||||
self.floor = P.Floor()
|
||||
self.issubclass = P.IsSubClass()
|
||||
self.less = P.Less()
|
||||
self.log = log_by_step
|
||||
self.pow = P.Pow()
|
||||
self.select = P.Select()
|
||||
self.shape = P.Shape()
|
||||
|
|
|
@ -20,7 +20,7 @@ from mindspore.common import dtype as mstype
|
|||
from .distribution import Distribution
|
||||
from ._utils.utils import convert_to_batch, check_greater_zero, check_type, check_distribution_name,\
|
||||
raise_none_error
|
||||
from ._utils.custom_ops import log_by_step, expm1_by_step
|
||||
from ._utils.custom_ops import exp_by_step, expm1_by_step, log_by_step
|
||||
|
||||
class Normal(Distribution):
|
||||
"""
|
||||
|
@ -114,14 +114,14 @@ class Normal(Distribution):
|
|||
self._sd_value = sd
|
||||
|
||||
#ops needed for the class
|
||||
self.exp = exp_by_step
|
||||
self.expm1 = expm1_by_step
|
||||
self.log = log_by_step
|
||||
self.squeeze = P.Squeeze(0)
|
||||
self.cast = P.Cast()
|
||||
self.const = P.ScalarToArray()
|
||||
self.erf = P.Erf()
|
||||
self.exp = P.Exp()
|
||||
self.expm1 = expm1_by_step
|
||||
self.fill = P.Fill()
|
||||
self.log = log_by_step
|
||||
self.shape = P.Shape()
|
||||
self.sq = P.Square()
|
||||
self.sqrt = P.Sqrt()
|
||||
|
|
|
@ -13,13 +13,12 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Transformed Distribution"""
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from mindspore.common import dtype as mstype
|
||||
import mindspore.nn as nn
|
||||
from .distribution import Distribution
|
||||
from ._utils.utils import check_type, raise_not_impl_error
|
||||
from ._utils.custom_ops import log_by_step
|
||||
from ._utils.custom_ops import exp_by_step, log_by_step
|
||||
|
||||
class TransformedDistribution(Distribution):
|
||||
"""
|
||||
|
@ -56,7 +55,7 @@ class TransformedDistribution(Distribution):
|
|||
self._bijector = bijector
|
||||
self._distribution = distribution
|
||||
self._is_linear_transformation = bijector.is_constant_jacobian
|
||||
self.exp = P.Exp()
|
||||
self.exp = exp_by_step
|
||||
self.log = log_by_step
|
||||
|
||||
@property
|
||||
|
|
|
@ -19,7 +19,7 @@ from mindspore.common import dtype as mstype
|
|||
from .distribution import Distribution
|
||||
from ._utils.utils import convert_to_batch, check_greater, check_type, check_distribution_name,\
|
||||
raise_none_error
|
||||
from ._utils.custom_ops import log_by_step
|
||||
from ._utils.custom_ops import exp_by_step, log_by_step
|
||||
|
||||
class Uniform(Distribution):
|
||||
"""
|
||||
|
@ -113,15 +113,15 @@ class Uniform(Distribution):
|
|||
self._high = high
|
||||
|
||||
# ops needed for the class
|
||||
self.exp = exp_by_step
|
||||
self.log = log_by_step
|
||||
self.squeeze = P.Squeeze(0)
|
||||
self.cast = P.Cast()
|
||||
self.const = P.ScalarToArray()
|
||||
self.dtypeop = P.DType()
|
||||
self.exp = P.Exp()
|
||||
self.fill = P.Fill()
|
||||
self.less = P.Less()
|
||||
self.lessequal = P.LessEqual()
|
||||
self.log = log_by_step
|
||||
self.logicaland = P.LogicalAnd()
|
||||
self.select = P.Select()
|
||||
self.shape = P.Shape()
|
||||
|
|
Loading…
Reference in New Issue