From 239826515a532c8aaf3ebae4d4bf1c19e0d7be5b Mon Sep 17 00:00:00 2001 From: Xun Deng Date: Wed, 16 Sep 2020 15:23:03 -0400 Subject: [PATCH] modified dtype cast logic in custom_log and custom_exp and fixed dtype issues in softplus bijector --- mindspore/nn/probability/bijector/softplus.py | 6 +++--- .../nn/probability/distribution/_utils/custom_ops.py | 9 +++++++-- mindspore/nn/probability/distribution/bernoulli.py | 4 ++-- mindspore/nn/probability/distribution/geometric.py | 4 ++-- 4 files changed, 14 insertions(+), 9 deletions(-) diff --git a/mindspore/nn/probability/bijector/softplus.py b/mindspore/nn/probability/bijector/softplus.py index d95f940a637..5fa5ca724ba 100644 --- a/mindspore/nn/probability/bijector/softplus.py +++ b/mindspore/nn/probability/bijector/softplus.py @@ -15,7 +15,6 @@ """Softplus Bijector""" import numpy as np from mindspore.ops import operations as P -from mindspore.common import dtype as mstype from mindspore.nn.layer.activation import LogSigmoid from mindspore._checkparam import Validator as validator from ..distribution._utils.utils import cast_to_tensor @@ -71,6 +70,7 @@ class Softplus(Bijector): self.log = log_generic self.expm1 = expm1_generic self.abs = P.Abs() + self.dtypeop = P.DType() self.fill = P.Fill() self.greater = P.Greater() self.less = P.Less() @@ -90,7 +90,7 @@ class Softplus(Bijector): too_large = self.greater(x, -self.threshold) too_small_value = self.exp(x) too_large_value = x - ones = self.fill(mstype.float32, self.shape(x), 1.0) + ones = self.fill(self.dtypeop(x), self.shape(x), 1.0) too_small_or_too_large = self.logicalor(too_small, too_large) x = self.select(too_small_or_too_large, ones, x) y = self.log(self.exp(x) + 1.0) @@ -106,7 +106,7 @@ class Softplus(Bijector): too_large = self.greater(x, -self.threshold) too_small_value = self.log(x) too_large_value = x - ones = self.fill(mstype.float32, self.shape(x), 1.0) + ones = self.fill(self.dtypeop(x), self.shape(x), 1.0) too_small_or_too_large = self.logicalor(too_small, too_large) x = self.select(too_small_or_too_large, ones, x) y = x + self.log(self.abs(self.expm1(-x))) diff --git a/mindspore/nn/probability/distribution/_utils/custom_ops.py b/mindspore/nn/probability/distribution/_utils/custom_ops.py index ad5f9d33ace..bda3ae3eaa3 100644 --- a/mindspore/nn/probability/distribution/_utils/custom_ops.py +++ b/mindspore/nn/probability/distribution/_utils/custom_ops.py @@ -24,8 +24,11 @@ def exp_generic(input_x): """ exp = P.Exp() cast = P.Cast() + dtype = P.DType() + checktype = P.IsSubClass() - input_x = cast(input_x, mstype.float32) + if not checktype(dtype(input_x), mstype.float_): + input_x = cast(input_x, mstype.float32) return exp(input_x) @@ -51,8 +54,10 @@ def log_generic(input_x): dtype = P.DType() shape = P.Shape() select = P.Select() + checktype = P.IsSubClass() - input_x = cast(input_x, mstype.float32) + if not checktype(dtype(input_x), mstype.float_): + input_x = cast(input_x, mstype.float32) nan = fill(dtype(input_x), shape(input_x), np.nan) inf = fill(dtype(input_x), shape(input_x), np.inf) neg_x = less(input_x, 0.0) diff --git a/mindspore/nn/probability/distribution/bernoulli.py b/mindspore/nn/probability/distribution/bernoulli.py index 933dcbec13a..6543f41aae9 100644 --- a/mindspore/nn/probability/distribution/bernoulli.py +++ b/mindspore/nn/probability/distribution/bernoulli.py @@ -222,7 +222,7 @@ class Bernoulli(Distribution): pmf(k) = probs0 if k = 0; """ value = self._check_value(value, 'value') - value = self.cast(value, mstype.float32) + value = self.cast(value, self.parameter_type) probs1 = self._check_param_type(probs1) probs0 = 1.0 - probs1 return self.log(probs1) * value + self.log(probs0) * (1.0 - value) @@ -241,7 +241,7 @@ class Bernoulli(Distribution): cdf(k) = 1 if k >=1; """ value = self._check_value(value, 'value') - value = self.cast(value, mstype.float32) + value = self.cast(value, self.parameter_type) value = self.floor(value) probs1 = self._check_param_type(probs1) prob_type = self.dtypeop(probs1) diff --git a/mindspore/nn/probability/distribution/geometric.py b/mindspore/nn/probability/distribution/geometric.py index 3eaddcf7f99..e2d6255ede8 100644 --- a/mindspore/nn/probability/distribution/geometric.py +++ b/mindspore/nn/probability/distribution/geometric.py @@ -225,7 +225,7 @@ class Geometric(Distribution): pmf(k) = 0 if k < 0. """ value = self._check_value(value, 'value') - value = self.cast(value, mstype.float32) + value = self.cast(value, self.parameter_type) value = self.floor(value) probs1 = self._check_param_type(probs1) pmf = self.exp(self.log(1.0 - probs1) * value + self.log(probs1)) @@ -247,7 +247,7 @@ class Geometric(Distribution): """ value = self._check_value(value, 'value') - value = self.cast(value, mstype.float32) + value = self.cast(value, self.parameter_type) value = self.floor(value) probs1 = self._check_param_type(probs1) probs0 = 1.0 - probs1