!8234 Remove expm1_generic and log1p_generic from PP utils

From: @peixu_ren
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2020-11-23 09:56:07 +08:00 committed by Gitee
commit 7f725b93a3
9 changed files with 17 additions and 25 deletions

View File

@ -15,7 +15,7 @@
"""Power Bijector"""
from mindspore.ops import operations as P
from ..distribution._utils.utils import check_greater_equal_zero
from ..distribution._utils.custom_ops import exp_generic, expm1_generic, log_generic, log1p_generic
from ..distribution._utils.custom_ops import exp_generic, log_generic
from .bijector import Bijector
@ -73,9 +73,9 @@ class PowerTransform(Bijector):
self.dtypeop = P.DType()
self.cast = P.Cast()
self.exp = exp_generic
self.expm1 = expm1_generic
self.expm1 = P.Expm1()
self.log = log_generic
self.log1p = log1p_generic
self.log1p = P.Log1p()
@property
def power(self):

View File

@ -16,7 +16,7 @@
import numpy as np
from mindspore.ops import operations as P
from mindspore.nn.layer.activation import LogSigmoid
from ..distribution._utils.custom_ops import exp_generic, expm1_generic, log_generic
from ..distribution._utils.custom_ops import exp_generic, log_generic
from .bijector import Bijector
@ -65,7 +65,7 @@ class Softplus(Bijector):
self.exp = exp_generic
self.log = log_generic
self.expm1 = expm1_generic
self.expm1 = P.Expm1()
self.abs = P.Abs()
self.dtypeop = P.DType()
self.cast = P.Cast()

View File

@ -25,9 +25,7 @@ __all__ = [
'check_greater_zero',
'check_prob',
'exp_generic',
'expm1_generic',
'log_generic',
'log1p_generic',
'broadcast_to',
'set_param_type',
'CheckTensor',

View File

@ -32,13 +32,6 @@ def exp_generic(input_x):
return exp(input_x)
def expm1_generic(input_x):
"""
Expm1 ops under GPU context.
"""
return exp_generic(input_x) - 1.0
def log_generic(input_x):
"""
Log op on Ascend is calculated as log(abs(x)).

View File

@ -22,7 +22,7 @@ import mindspore.nn.probability.bijector as msb
import mindspore.nn.probability.distribution as msd
from .transformed_distribution import TransformedDistribution
from ._utils.utils import check_distribution_name
from ._utils.custom_ops import exp_generic, expm1_generic, log_generic
from ._utils.custom_ops import exp_generic, log_generic
class Gumbel(TransformedDistribution):
"""
@ -120,7 +120,7 @@ class Gumbel(TransformedDistribution):
self.cast = P.Cast()
self.const = P.ScalarToArray()
self.exp = exp_generic
self.expm1 = expm1_generic
self.expm1 = P.Expm1()
self.fill = P.Fill()
self.lgamma = nn.LGamma()
self.log = log_generic

View File

@ -19,7 +19,7 @@ from mindspore.common import dtype as mstype
import mindspore.nn.probability.bijector as msb
import mindspore.nn.probability.distribution as msd
from ._utils.utils import check_distribution_name
from ._utils.custom_ops import exp_generic, expm1_generic, log_generic
from ._utils.custom_ops import exp_generic, log_generic
class LogNormal(msd.TransformedDistribution):
"""
@ -146,7 +146,7 @@ class LogNormal(msd.TransformedDistribution):
#ops needed for the class
self.exp = exp_generic
self.expm1 = expm1_generic
self.expm1 = P.Expm1()
self.log = log_generic
self.const = P.ScalarToArray()
self.erf = P.Erf()

View File

@ -20,7 +20,7 @@ from mindspore._checkparam import Validator
from mindspore.common import dtype as mstype
from .distribution import Distribution
from ._utils.utils import check_greater_zero
from ._utils.custom_ops import exp_generic, expm1_generic, log_generic, log1p_generic
from ._utils.custom_ops import exp_generic, log_generic
class Logistic(Distribution):
@ -124,11 +124,11 @@ class Logistic(Distribution):
self.const = P.ScalarToArray()
self.dtypeop = P.DType()
self.exp = exp_generic
self.expm1 = expm1_generic
self.expm1 = P.Expm1()
self.fill = P.Fill()
self.less = P.Less()
self.log = log_generic
self.log1p = log1p_generic
self.log1p = P.Log1p()
self.logicalor = P.LogicalOr()
self.erf = P.Erf()
self.greater = P.Greater()

View File

@ -20,7 +20,7 @@ from mindspore._checkparam import Validator
from mindspore.common import dtype as mstype
from .distribution import Distribution
from ._utils.utils import check_greater_zero, check_distribution_name
from ._utils.custom_ops import exp_generic, expm1_generic, log_generic
from ._utils.custom_ops import exp_generic, log_generic
class Normal(Distribution):
@ -137,7 +137,7 @@ class Normal(Distribution):
# ops needed for the class
self.exp = exp_generic
self.expm1 = expm1_generic
self.expm1 = P.Expm1()
self.log = log_generic
self.erf = P.Erf()
self.squeeze = P.Squeeze(0)

View File

@ -175,15 +175,16 @@ class LogNormalBasics(nn.Cell):
def construct(self):
mean = self.n.mean()
sd = self.n.sd()
mode = self.n.mode()
entropy = self.n.entropy()
return mean + sd + mode + entropy
return mean + mode + entropy
def test_bascis():
"""
Test mean/sd/mode/entropy functionality of LogNormal.
"""
from mindspore import context
context.set_context(device_target="Ascend")
net = LogNormalBasics()
ans = net()
assert isinstance(ans, Tensor)