forked from mindspore-Ecosystem/mindspore
Add custom op interface to replace expm1, log1p and log
This commit is contained in:
parent
29e21479a4
commit
8e0343830e
|
@ -17,6 +17,7 @@ from mindspore.ops import operations as P
|
||||||
from mindspore._checkparam import Validator as validator
|
from mindspore._checkparam import Validator as validator
|
||||||
from mindspore._checkparam import Rel
|
from mindspore._checkparam import Rel
|
||||||
from ..distribution._utils.utils import CheckTensor
|
from ..distribution._utils.utils import CheckTensor
|
||||||
|
from ..distribution._utils.custom_ops import log_by_step, log1p_by_step, expm1_by_step
|
||||||
from .bijector import Bijector
|
from .bijector import Bijector
|
||||||
|
|
||||||
class PowerTransform(Bijector):
|
class PowerTransform(Bijector):
|
||||||
|
@ -59,24 +60,12 @@ class PowerTransform(Bijector):
|
||||||
self._power = power
|
self._power = power
|
||||||
self.pow = P.Pow()
|
self.pow = P.Pow()
|
||||||
self.exp = P.Exp()
|
self.exp = P.Exp()
|
||||||
self.log = P.Log()
|
self.log = log_by_step
|
||||||
self.log1p = self._log1p_by_step
|
self.log1p = log1p_by_step
|
||||||
self.expm1 = self._expm1_by_step
|
self.expm1 = expm1_by_step
|
||||||
|
|
||||||
self.checktensor = CheckTensor()
|
self.checktensor = CheckTensor()
|
||||||
|
|
||||||
def _log1p_by_step(self, x):
|
|
||||||
"""
|
|
||||||
Log1p ops on GPU device or when device_target == GPU.
|
|
||||||
"""
|
|
||||||
return self.log(x + 1.0)
|
|
||||||
|
|
||||||
def _expm1_by_step(self, x):
|
|
||||||
"""
|
|
||||||
Expm1 ops on GPU device or when device_target == GPU.
|
|
||||||
"""
|
|
||||||
return self.exp(x) - 1.0
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def power(self):
|
def power(self):
|
||||||
return self._power
|
return self._power
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
from mindspore._checkparam import Validator as validator
|
from mindspore._checkparam import Validator as validator
|
||||||
from ..distribution._utils.utils import cast_to_tensor, CheckTensor
|
from ..distribution._utils.utils import cast_to_tensor, CheckTensor
|
||||||
|
from ..distribution._utils.custom_ops import log_by_step
|
||||||
from .bijector import Bijector
|
from .bijector import Bijector
|
||||||
|
|
||||||
class ScalarAffine(Bijector):
|
class ScalarAffine(Bijector):
|
||||||
|
@ -66,7 +67,7 @@ class ScalarAffine(Bijector):
|
||||||
param=param)
|
param=param)
|
||||||
|
|
||||||
self.abs = P.Abs()
|
self.abs = P.Abs()
|
||||||
self.log = P.Log()
|
self.log = log_by_step
|
||||||
|
|
||||||
self.checktensor = CheckTensor()
|
self.checktensor = CheckTensor()
|
||||||
|
|
||||||
|
|
|
@ -19,6 +19,7 @@ from mindspore.common import dtype as mstype
|
||||||
from mindspore.nn.layer.activation import LogSigmoid
|
from mindspore.nn.layer.activation import LogSigmoid
|
||||||
from mindspore._checkparam import Validator as validator
|
from mindspore._checkparam import Validator as validator
|
||||||
from ..distribution._utils.utils import cast_to_tensor, CheckTensor
|
from ..distribution._utils.utils import cast_to_tensor, CheckTensor
|
||||||
|
from ..distribution._utils.custom_ops import log_by_step, expm1_by_step
|
||||||
from .bijector import Bijector
|
from .bijector import Bijector
|
||||||
|
|
||||||
class Softplus(Bijector):
|
class Softplus(Bijector):
|
||||||
|
@ -60,12 +61,12 @@ class Softplus(Bijector):
|
||||||
|
|
||||||
self.abs = P.Abs()
|
self.abs = P.Abs()
|
||||||
self.exp = P.Exp()
|
self.exp = P.Exp()
|
||||||
self.expm1 = self._expm1_by_step
|
self.log = log_by_step
|
||||||
|
self.expm1 = expm1_by_step
|
||||||
self.fill = P.Fill()
|
self.fill = P.Fill()
|
||||||
self.greater = P.Greater()
|
self.greater = P.Greater()
|
||||||
self.less = P.Less()
|
self.less = P.Less()
|
||||||
self.log_sigmoid = LogSigmoid()
|
self.log_sigmoid = LogSigmoid()
|
||||||
self.log = P.Log()
|
|
||||||
self.logicalor = P.LogicalOr()
|
self.logicalor = P.LogicalOr()
|
||||||
self.select = P.Select()
|
self.select = P.Select()
|
||||||
self.shape = P.Shape()
|
self.shape = P.Shape()
|
||||||
|
@ -76,12 +77,6 @@ class Softplus(Bijector):
|
||||||
self.checktensor = CheckTensor()
|
self.checktensor = CheckTensor()
|
||||||
self.threshold = np.log(np.finfo(np.float32).eps) + 1
|
self.threshold = np.log(np.finfo(np.float32).eps) + 1
|
||||||
|
|
||||||
def _expm1_by_step(self, x):
|
|
||||||
"""
|
|
||||||
Expm1 ops under GPU context.
|
|
||||||
"""
|
|
||||||
return self.exp(x) - 1.0
|
|
||||||
|
|
||||||
def _softplus(self, x):
|
def _softplus(self, x):
|
||||||
too_small = self.less(x, self.threshold)
|
too_small = self.less(x, self.threshold)
|
||||||
too_large = self.greater(x, -self.threshold)
|
too_large = self.greater(x, -self.threshold)
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
Distribution operation utility functions.
|
Distribution operation utility functions.
|
||||||
"""
|
"""
|
||||||
from .utils import *
|
from .utils import *
|
||||||
|
from .custom_ops import *
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'convert_to_batch',
|
'convert_to_batch',
|
||||||
|
@ -27,4 +28,7 @@ __all__ = [
|
||||||
'check_scalar_from_param',
|
'check_scalar_from_param',
|
||||||
'check_prob',
|
'check_prob',
|
||||||
'check_type',
|
'check_type',
|
||||||
|
'log_by_step',
|
||||||
|
'log1p_by_step',
|
||||||
|
'expm1_by_step',
|
||||||
]
|
]
|
||||||
|
|
|
@ -0,0 +1,48 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Utitly functions to help distribution class."""
|
||||||
|
import numpy as np
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
|
||||||
|
def log_by_step(input_x):
|
||||||
|
"""
|
||||||
|
Log op on Ascend is calculated as log(abs(x)).
|
||||||
|
Fix this with putting negative values as nan.
|
||||||
|
"""
|
||||||
|
select = P.Select()
|
||||||
|
log = P.Log()
|
||||||
|
lessequal = P.LessEqual()
|
||||||
|
fill = P.Fill()
|
||||||
|
dtype = P.DType()
|
||||||
|
shape = P.Shape()
|
||||||
|
|
||||||
|
nonpos_x = lessequal(input_x, 0.0)
|
||||||
|
log_x = log(input_x)
|
||||||
|
nan = fill(dtype(input_x), shape(input_x), np.nan)
|
||||||
|
result = select(nonpos_x, nan, log_x)
|
||||||
|
return result
|
||||||
|
|
||||||
|
def log1p_by_step(x):
|
||||||
|
"""
|
||||||
|
Log1p ops on GPU device or when device_target == GPU.
|
||||||
|
"""
|
||||||
|
return log_by_step(x + 1.0)
|
||||||
|
|
||||||
|
def expm1_by_step(input_x):
|
||||||
|
"""
|
||||||
|
Expm1 ops under GPU context.
|
||||||
|
"""
|
||||||
|
exp = P.Exp()
|
||||||
|
return exp(input_x) - 1.0
|
|
@ -19,6 +19,7 @@ from mindspore.ops import composite as C
|
||||||
from .distribution import Distribution
|
from .distribution import Distribution
|
||||||
from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name, raise_none_error
|
from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name, raise_none_error
|
||||||
from ._utils.utils import CheckTensor, CheckTuple
|
from ._utils.utils import CheckTensor, CheckTuple
|
||||||
|
from ._utils.custom_ops import log_by_step
|
||||||
|
|
||||||
class Bernoulli(Distribution):
|
class Bernoulli(Distribution):
|
||||||
"""
|
"""
|
||||||
|
@ -116,7 +117,7 @@ class Bernoulli(Distribution):
|
||||||
self.exp = P.Exp()
|
self.exp = P.Exp()
|
||||||
self.floor = P.Floor()
|
self.floor = P.Floor()
|
||||||
self.fill = P.Fill()
|
self.fill = P.Fill()
|
||||||
self.log = P.Log()
|
self.log = log_by_step
|
||||||
self.less = P.Less()
|
self.less = P.Less()
|
||||||
self.shape = P.Shape()
|
self.shape = P.Shape()
|
||||||
self.select = P.Select()
|
self.select = P.Select()
|
||||||
|
|
|
@ -21,6 +21,7 @@ from .distribution import Distribution
|
||||||
from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name,\
|
from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name,\
|
||||||
raise_none_error
|
raise_none_error
|
||||||
from ._utils.utils import CheckTensor, CheckTuple
|
from ._utils.utils import CheckTensor, CheckTuple
|
||||||
|
from ._utils.custom_ops import log_by_step
|
||||||
|
|
||||||
class Exponential(Distribution):
|
class Exponential(Distribution):
|
||||||
"""
|
"""
|
||||||
|
@ -119,7 +120,7 @@ class Exponential(Distribution):
|
||||||
self.exp = P.Exp()
|
self.exp = P.Exp()
|
||||||
self.fill = P.Fill()
|
self.fill = P.Fill()
|
||||||
self.less = P.Less()
|
self.less = P.Less()
|
||||||
self.log = P.Log()
|
self.log = log_by_step
|
||||||
self.select = P.Select()
|
self.select = P.Select()
|
||||||
self.shape = P.Shape()
|
self.shape = P.Shape()
|
||||||
self.sqrt = P.Sqrt()
|
self.sqrt = P.Sqrt()
|
||||||
|
|
|
@ -21,6 +21,7 @@ from .distribution import Distribution
|
||||||
from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name,\
|
from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name,\
|
||||||
raise_none_error
|
raise_none_error
|
||||||
from ._utils.utils import CheckTensor, CheckTuple
|
from ._utils.utils import CheckTensor, CheckTuple
|
||||||
|
from ._utils.custom_ops import log_by_step
|
||||||
|
|
||||||
class Geometric(Distribution):
|
class Geometric(Distribution):
|
||||||
"""
|
"""
|
||||||
|
@ -122,7 +123,7 @@ class Geometric(Distribution):
|
||||||
self.floor = P.Floor()
|
self.floor = P.Floor()
|
||||||
self.issubclass = P.IsSubClass()
|
self.issubclass = P.IsSubClass()
|
||||||
self.less = P.Less()
|
self.less = P.Less()
|
||||||
self.log = P.Log()
|
self.log = log_by_step
|
||||||
self.pow = P.Pow()
|
self.pow = P.Pow()
|
||||||
self.select = P.Select()
|
self.select = P.Select()
|
||||||
self.shape = P.Shape()
|
self.shape = P.Shape()
|
||||||
|
|
|
@ -21,6 +21,7 @@ from .distribution import Distribution
|
||||||
from ._utils.utils import convert_to_batch, check_greater_zero, check_type, check_distribution_name,\
|
from ._utils.utils import convert_to_batch, check_greater_zero, check_type, check_distribution_name,\
|
||||||
raise_none_error
|
raise_none_error
|
||||||
from ._utils.utils import CheckTensor, CheckTuple
|
from ._utils.utils import CheckTensor, CheckTuple
|
||||||
|
from ._utils.custom_ops import log_by_step, expm1_by_step
|
||||||
|
|
||||||
class Normal(Distribution):
|
class Normal(Distribution):
|
||||||
"""
|
"""
|
||||||
|
@ -119,9 +120,9 @@ class Normal(Distribution):
|
||||||
self.const = P.ScalarToArray()
|
self.const = P.ScalarToArray()
|
||||||
self.erf = P.Erf()
|
self.erf = P.Erf()
|
||||||
self.exp = P.Exp()
|
self.exp = P.Exp()
|
||||||
self.expm1 = self._expm1_by_step
|
self.expm1 = expm1_by_step
|
||||||
self.fill = P.Fill()
|
self.fill = P.Fill()
|
||||||
self.log = P.Log()
|
self.log = log_by_step
|
||||||
self.shape = P.Shape()
|
self.shape = P.Shape()
|
||||||
self.sq = P.Square()
|
self.sq = P.Square()
|
||||||
self.sqrt = P.Sqrt()
|
self.sqrt = P.Sqrt()
|
||||||
|
@ -137,12 +138,6 @@ class Normal(Distribution):
|
||||||
str_info = f'batch_shape = {self._broadcast_shape}'
|
str_info = f'batch_shape = {self._broadcast_shape}'
|
||||||
return str_info
|
return str_info
|
||||||
|
|
||||||
def _expm1_by_step(self, x):
|
|
||||||
"""
|
|
||||||
Expm1 ops under GPU context.
|
|
||||||
"""
|
|
||||||
return self.exp(x) - 1.0
|
|
||||||
|
|
||||||
def _check_param(self, mean, sd):
|
def _check_param(self, mean, sd):
|
||||||
"""
|
"""
|
||||||
Check availablity of distribution specific args mean and sd.
|
Check availablity of distribution specific args mean and sd.
|
||||||
|
|
|
@ -19,6 +19,7 @@ from mindspore.common import dtype as mstype
|
||||||
import mindspore.nn as nn
|
import mindspore.nn as nn
|
||||||
from .distribution import Distribution
|
from .distribution import Distribution
|
||||||
from ._utils.utils import check_type, raise_not_impl_error
|
from ._utils.utils import check_type, raise_not_impl_error
|
||||||
|
from ._utils.custom_ops import log_by_step
|
||||||
|
|
||||||
class TransformedDistribution(Distribution):
|
class TransformedDistribution(Distribution):
|
||||||
"""
|
"""
|
||||||
|
@ -56,7 +57,7 @@ class TransformedDistribution(Distribution):
|
||||||
self._distribution = distribution
|
self._distribution = distribution
|
||||||
self._is_linear_transformation = bijector.is_constant_jacobian
|
self._is_linear_transformation = bijector.is_constant_jacobian
|
||||||
self.exp = P.Exp()
|
self.exp = P.Exp()
|
||||||
self.log = P.Log()
|
self.log = log_by_step
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def bijector(self):
|
def bijector(self):
|
||||||
|
|
|
@ -20,6 +20,7 @@ from .distribution import Distribution
|
||||||
from ._utils.utils import convert_to_batch, check_greater, check_type, check_distribution_name,\
|
from ._utils.utils import convert_to_batch, check_greater, check_type, check_distribution_name,\
|
||||||
raise_none_error
|
raise_none_error
|
||||||
from ._utils.utils import CheckTensor, CheckTuple
|
from ._utils.utils import CheckTensor, CheckTuple
|
||||||
|
from ._utils.custom_ops import log_by_step
|
||||||
|
|
||||||
class Uniform(Distribution):
|
class Uniform(Distribution):
|
||||||
"""
|
"""
|
||||||
|
@ -121,7 +122,7 @@ class Uniform(Distribution):
|
||||||
self.fill = P.Fill()
|
self.fill = P.Fill()
|
||||||
self.less = P.Less()
|
self.less = P.Less()
|
||||||
self.lessequal = P.LessEqual()
|
self.lessequal = P.LessEqual()
|
||||||
self.log = P.Log()
|
self.log = log_by_step
|
||||||
self.logicaland = P.LogicalAnd()
|
self.logicaland = P.LogicalAnd()
|
||||||
self.select = P.Select()
|
self.select = P.Select()
|
||||||
self.shape = P.Shape()
|
self.shape = P.Shape()
|
||||||
|
|
Loading…
Reference in New Issue