diff --git a/mindspore/nn/probability/bijector/power_transform.py b/mindspore/nn/probability/bijector/power_transform.py
index 1bf747c08a1..c6e2b9a635e 100644
--- a/mindspore/nn/probability/bijector/power_transform.py
+++ b/mindspore/nn/probability/bijector/power_transform.py
@@ -17,7 +17,7 @@ from mindspore.ops import operations as P
 from mindspore._checkparam import Validator as validator
 from mindspore._checkparam import Rel
 from ..distribution._utils.utils import CheckTensor
-from ..distribution._utils.custom_ops import log_by_step, log1p_by_step, expm1_by_step
+from ..distribution._utils.custom_ops import exp_by_step, expm1_by_step, log_by_step, log1p_by_step
 from .bijector import Bijector
 
 class PowerTransform(Bijector):
@@ -59,10 +59,10 @@ class PowerTransform(Bijector):
         validator.check_number("power", power, 0, Rel.GE, self.name)
         self._power = power
         self.pow = P.Pow()
-        self.exp = P.Exp()
+        self.exp = exp_by_step
+        self.expm1 = expm1_by_step
         self.log = log_by_step
         self.log1p = log1p_by_step
-        self.expm1 = expm1_by_step
         self.checktensor = CheckTensor()
 
diff --git a/mindspore/nn/probability/bijector/softplus.py b/mindspore/nn/probability/bijector/softplus.py
index 69ea2d8d05f..9c0fc4e5f8a 100644
--- a/mindspore/nn/probability/bijector/softplus.py
+++ b/mindspore/nn/probability/bijector/softplus.py
@@ -19,7 +19,7 @@ from mindspore.common import dtype as mstype
 from mindspore.nn.layer.activation import LogSigmoid
 from mindspore._checkparam import Validator as validator
 from ..distribution._utils.utils import cast_to_tensor, CheckTensor
-from ..distribution._utils.custom_ops import log_by_step, expm1_by_step
+from ..distribution._utils.custom_ops import exp_by_step, expm1_by_step, log_by_step
 from .bijector import Bijector
 
 class Softplus(Bijector):
@@ -59,10 +59,10 @@ class Softplus(Bijector):
         super(Softplus, self).__init__(name=name, param=param)
         self._sharpness = cast_to_tensor(sharpness)
 
-        self.abs = P.Abs()
-        self.exp = P.Exp()
+        self.exp = exp_by_step
         self.log = log_by_step
         self.expm1 = expm1_by_step
+        self.abs = P.Abs()
         self.fill = P.Fill()
         self.greater = P.Greater()
         self.less = P.Less()
diff --git a/mindspore/nn/probability/distribution/_utils/__init__.py b/mindspore/nn/probability/distribution/_utils/__init__.py
index 07ebb623b54..a2ea9e8f8bc 100644
--- a/mindspore/nn/probability/distribution/_utils/__init__.py
+++ b/mindspore/nn/probability/distribution/_utils/__init__.py
@@ -28,7 +28,8 @@ __all__ = [
     'check_scalar_from_param',
     'check_prob',
     'check_type',
+    'exp_by_step',
+    'expm1_by_step',
     'log_by_step',
     'log1p_by_step',
-    'expm1_by_step',
 ]
diff --git a/mindspore/nn/probability/distribution/_utils/custom_ops.py b/mindspore/nn/probability/distribution/_utils/custom_ops.py
index e8acac6a07f..b81f5d186ab 100644
--- a/mindspore/nn/probability/distribution/_utils/custom_ops.py
+++ b/mindspore/nn/probability/distribution/_utils/custom_ops.py
@@ -17,10 +17,36 @@ import numpy as np
 from mindspore.ops import operations as P
 from mindspore.common import dtype as mstype
 
+def exp_by_step(input_x):
+    """
+    Exp op on Ascend doesn't support int types.
+    Fix this by casting the type.
+    """
+    exp = P.Exp()
+    cast = P.Cast()
+    dtype = P.DType()
+    checktype = P.IsSubClass()
+
+    if checktype(dtype(input_x), mstype.int_):
+        input_x = cast(input_x, mstype.float32)
+    elif checktype(dtype(input_x), mstype.float_):
+        pass
+    else:
+        return None
+    return exp(input_x)
+
+def expm1_by_step(input_x):
+    """
+    Expm1 op, implemented via exp_by_step.
+ """ + return exp_by_step(input_x) - 1.0 + def log_by_step(input_x): """ Log op on Ascend is calculated as log(abs(x)). Fix this with putting negative values as nan. + And log op on Ascend doesn't supprot int types. + Fix this with casting the type. """ log = P.Log() less = P.Less() @@ -30,8 +56,14 @@ def log_by_step(input_x): dtype = P.DType() shape = P.Shape() select = P.Select() + checktype = P.IsSubClass() - input_x = cast(input_x, mstype.float32) + if checktype(dtype(input_x), mstype.int_): + input_x = cast(input_x, mstype.float32) + elif checktype(dtype(input_x), mstype.float_): + pass + else: + return None nan = fill(dtype(input_x), shape(input_x), np.nan) inf = fill(dtype(input_x), shape(input_x), np.inf) neg_x = less(input_x, 0.0) @@ -45,10 +77,3 @@ def log1p_by_step(x): Log1p ops on GPU device or when device_target == GPU. """ return log_by_step(x + 1.0) - -def expm1_by_step(input_x): - """ - Expm1 ops under GPU context. - """ - exp = P.Exp() - return exp(input_x) - 1.0 diff --git a/mindspore/nn/probability/distribution/bernoulli.py b/mindspore/nn/probability/distribution/bernoulli.py index 2ef9ed83215..ee673833f32 100644 --- a/mindspore/nn/probability/distribution/bernoulli.py +++ b/mindspore/nn/probability/distribution/bernoulli.py @@ -18,7 +18,7 @@ from mindspore.ops import operations as P from mindspore.ops import composite as C from .distribution import Distribution from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name, raise_none_error -from ._utils.custom_ops import log_by_step +from ._utils.custom_ops import exp_by_step, log_by_step class Bernoulli(Distribution): """ @@ -108,15 +108,15 @@ class Bernoulli(Distribution): self._probs = probs # ops needed for the class + self.exp = exp_by_step + self.log = log_by_step self.squeeze = P.Squeeze(0) self.cast = P.Cast() self.const = P.ScalarToArray() self.dtypeop = P.DType() self.erf = P.Erf() - self.exp = P.Exp() self.floor = P.Floor() self.fill = P.Fill() - self.log = log_by_step self.less = P.Less() self.shape = P.Shape() self.select = P.Select() diff --git a/mindspore/nn/probability/distribution/exponential.py b/mindspore/nn/probability/distribution/exponential.py index 1311a43c585..f9e9dca8f5a 100644 --- a/mindspore/nn/probability/distribution/exponential.py +++ b/mindspore/nn/probability/distribution/exponential.py @@ -20,7 +20,7 @@ from mindspore.common import dtype as mstype from .distribution import Distribution from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name,\ raise_none_error -from ._utils.custom_ops import log_by_step +from ._utils.custom_ops import exp_by_step, log_by_step class Exponential(Distribution): """ @@ -112,14 +112,14 @@ class Exponential(Distribution): self.minval = np.finfo(np.float).tiny # ops needed for the class + self.exp = exp_by_step + self.log = log_by_step self.squeeze = P.Squeeze(0) self.cast = P.Cast() self.const = P.ScalarToArray() self.dtypeop = P.DType() - self.exp = P.Exp() self.fill = P.Fill() self.less = P.Less() - self.log = log_by_step self.select = P.Select() self.shape = P.Shape() self.sqrt = P.Sqrt() @@ -277,8 +277,8 @@ class Exponential(Distribution): minval = self.const(self.minval) maxval = self.const(1.0) sample_uniform = self.uniform(sample_shape, minval, maxval, self.seed) - sample = -self.log(sample_uniform) / rate - value = self.cast(sample, self.dtype) + sample = self.log(sample_uniform) / rate + value = self.cast(-sample, self.dtype) if origin_shape == (): value = self.squeeze(value) 
         return value
diff --git a/mindspore/nn/probability/distribution/geometric.py b/mindspore/nn/probability/distribution/geometric.py
index 8065af53a57..19531aad44b 100644
--- a/mindspore/nn/probability/distribution/geometric.py
+++ b/mindspore/nn/probability/distribution/geometric.py
@@ -20,7 +20,7 @@ from mindspore.common import dtype as mstype
 from .distribution import Distribution
 from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name,\
     raise_none_error
-from ._utils.custom_ops import log_by_step
+from ._utils.custom_ops import exp_by_step, log_by_step
 
 class Geometric(Distribution):
     """
@@ -113,16 +113,16 @@ class Geometric(Distribution):
         self.minval = np.finfo(np.float).tiny
 
         # ops needed for the class
+        self.exp = exp_by_step
+        self.log = log_by_step
         self.squeeze = P.Squeeze(0)
         self.cast = P.Cast()
         self.const = P.ScalarToArray()
         self.dtypeop = P.DType()
-        self.exp = P.Exp()
         self.fill = P.Fill()
         self.floor = P.Floor()
         self.issubclass = P.IsSubClass()
         self.less = P.Less()
-        self.log = log_by_step
         self.pow = P.Pow()
         self.select = P.Select()
         self.shape = P.Shape()
diff --git a/mindspore/nn/probability/distribution/normal.py b/mindspore/nn/probability/distribution/normal.py
index 9993c5ec093..666394e14d9 100644
--- a/mindspore/nn/probability/distribution/normal.py
+++ b/mindspore/nn/probability/distribution/normal.py
@@ -20,7 +20,7 @@ from mindspore.common import dtype as mstype
 from .distribution import Distribution
 from ._utils.utils import convert_to_batch, check_greater_zero, check_type, check_distribution_name,\
     raise_none_error
-from ._utils.custom_ops import log_by_step, expm1_by_step
+from ._utils.custom_ops import exp_by_step, expm1_by_step, log_by_step
 
 class Normal(Distribution):
     """
@@ -114,14 +114,14 @@ class Normal(Distribution):
         self._sd_value = sd
 
         #ops needed for the class
+        self.exp = exp_by_step
+        self.expm1 = expm1_by_step
+        self.log = log_by_step
         self.squeeze = P.Squeeze(0)
         self.cast = P.Cast()
         self.const = P.ScalarToArray()
         self.erf = P.Erf()
-        self.exp = P.Exp()
-        self.expm1 = expm1_by_step
         self.fill = P.Fill()
-        self.log = log_by_step
         self.shape = P.Shape()
         self.sq = P.Square()
         self.sqrt = P.Sqrt()
diff --git a/mindspore/nn/probability/distribution/transformed_distribution.py b/mindspore/nn/probability/distribution/transformed_distribution.py
index 9ca9f6bdf13..cee32c820eb 100644
--- a/mindspore/nn/probability/distribution/transformed_distribution.py
+++ b/mindspore/nn/probability/distribution/transformed_distribution.py
@@ -13,13 +13,12 @@
 # limitations under the License.
 # ============================================================================
 """Transformed Distribution"""
-from mindspore.ops import operations as P
 from mindspore._checkparam import Validator as validator
 from mindspore.common import dtype as mstype
 import mindspore.nn as nn
 from .distribution import Distribution
 from ._utils.utils import check_type, raise_not_impl_error
-from ._utils.custom_ops import log_by_step
+from ._utils.custom_ops import exp_by_step, log_by_step
 
 class TransformedDistribution(Distribution):
     """
@@ -56,7 +55,7 @@ class TransformedDistribution(Distribution):
         self._bijector = bijector
         self._distribution = distribution
         self._is_linear_transformation = bijector.is_constant_jacobian
-        self.exp = P.Exp()
+        self.exp = exp_by_step
         self.log = log_by_step
 
     @property
diff --git a/mindspore/nn/probability/distribution/uniform.py b/mindspore/nn/probability/distribution/uniform.py
index d5d3aa6f34f..4e66fe50551 100644
--- a/mindspore/nn/probability/distribution/uniform.py
+++ b/mindspore/nn/probability/distribution/uniform.py
@@ -19,7 +19,7 @@ from mindspore.common import dtype as mstype
 from .distribution import Distribution
 from ._utils.utils import convert_to_batch, check_greater, check_type, check_distribution_name,\
     raise_none_error
-from ._utils.custom_ops import log_by_step
+from ._utils.custom_ops import exp_by_step, log_by_step
 
 class Uniform(Distribution):
     """
@@ -113,15 +113,15 @@ class Uniform(Distribution):
         self._high = high
 
         # ops needed for the class
+        self.exp = exp_by_step
+        self.log = log_by_step
         self.squeeze = P.Squeeze(0)
         self.cast = P.Cast()
         self.const = P.ScalarToArray()
         self.dtypeop = P.DType()
-        self.exp = P.Exp()
         self.fill = P.Fill()
         self.less = P.Less()
         self.lessequal = P.LessEqual()
-        self.log = log_by_step
         self.logicaland = P.LogicalAnd()
         self.select = P.Select()
         self.shape = P.Shape()
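---
Note (not part of the patch): the new exp_by_step/log_by_step helpers exist because, per the docstrings above, the Exp and Log kernels on Ascend reject integer inputs, and Log additionally computes log(abs(x)). The helpers therefore cast int tensors to float32 and route negative log inputs to nan. A minimal standalone NumPy sketch of that dtype-handling contract (log_by_step_reference is a hypothetical name for illustration only; this is not MindSpore code):

```python
import numpy as np

def log_by_step_reference(x):
    """Illustrates the intended semantics of log_by_step above."""
    x = np.asarray(x)
    if np.issubdtype(x.dtype, np.integer):
        x = x.astype(np.float32)   # the cast the patch adds for int inputs
    elif not np.issubdtype(x.dtype, np.floating):
        return None                # mirrors the `return None` fallthrough
    with np.errstate(divide='ignore', invalid='ignore'):
        return np.log(x)           # nan for x < 0, -inf for x == 0

print(log_by_step_reference(np.array([-1, 0, 1, 2])))
# [      nan      -inf 0.        0.6931472]
```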
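The exponential.py hunk also reorders the inverse-CDF sampling step X = -ln(U) / rate so that the negation happens at the cast rather than inside the log expression. A quick standalone NumPy check (illustrative only, not MindSpore code) that the two forms produce identical samples:

```python
import numpy as np

rate = 2.0
rng = np.random.default_rng(0)
# Draw U from (tiny, 1.0), matching the minval used in exponential.py.
u = rng.uniform(np.finfo(np.float32).tiny, 1.0, size=5)

before = -np.log(u) / rate    # old form: negate inside the sample expression
after = -(np.log(u) / rate)   # new form: negate when casting the result
assert np.array_equal(before, after)  # IEEE negation commutes with division
```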