modified dtype cast logic in custom_log and custom_exp and fixed dtype issues in softplus bijector

This commit is contained in:
Xun Deng 2020-09-16 15:23:03 -04:00
parent 1256737a7c
commit 239826515a
4 changed files with 14 additions and 9 deletions

View File

@ -15,7 +15,6 @@
"""Softplus Bijector""" """Softplus Bijector"""
import numpy as np import numpy as np
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
from mindspore.nn.layer.activation import LogSigmoid from mindspore.nn.layer.activation import LogSigmoid
from mindspore._checkparam import Validator as validator from mindspore._checkparam import Validator as validator
from ..distribution._utils.utils import cast_to_tensor from ..distribution._utils.utils import cast_to_tensor
@ -71,6 +70,7 @@ class Softplus(Bijector):
self.log = log_generic self.log = log_generic
self.expm1 = expm1_generic self.expm1 = expm1_generic
self.abs = P.Abs() self.abs = P.Abs()
self.dtypeop = P.DType()
self.fill = P.Fill() self.fill = P.Fill()
self.greater = P.Greater() self.greater = P.Greater()
self.less = P.Less() self.less = P.Less()
@ -90,7 +90,7 @@ class Softplus(Bijector):
too_large = self.greater(x, -self.threshold) too_large = self.greater(x, -self.threshold)
too_small_value = self.exp(x) too_small_value = self.exp(x)
too_large_value = x too_large_value = x
ones = self.fill(mstype.float32, self.shape(x), 1.0) ones = self.fill(self.dtypeop(x), self.shape(x), 1.0)
too_small_or_too_large = self.logicalor(too_small, too_large) too_small_or_too_large = self.logicalor(too_small, too_large)
x = self.select(too_small_or_too_large, ones, x) x = self.select(too_small_or_too_large, ones, x)
y = self.log(self.exp(x) + 1.0) y = self.log(self.exp(x) + 1.0)
@ -106,7 +106,7 @@ class Softplus(Bijector):
too_large = self.greater(x, -self.threshold) too_large = self.greater(x, -self.threshold)
too_small_value = self.log(x) too_small_value = self.log(x)
too_large_value = x too_large_value = x
ones = self.fill(mstype.float32, self.shape(x), 1.0) ones = self.fill(self.dtypeop(x), self.shape(x), 1.0)
too_small_or_too_large = self.logicalor(too_small, too_large) too_small_or_too_large = self.logicalor(too_small, too_large)
x = self.select(too_small_or_too_large, ones, x) x = self.select(too_small_or_too_large, ones, x)
y = x + self.log(self.abs(self.expm1(-x))) y = x + self.log(self.abs(self.expm1(-x)))

View File

@ -24,8 +24,11 @@ def exp_generic(input_x):
""" """
exp = P.Exp() exp = P.Exp()
cast = P.Cast() cast = P.Cast()
dtype = P.DType()
checktype = P.IsSubClass()
input_x = cast(input_x, mstype.float32) if not checktype(dtype(input_x), mstype.float_):
input_x = cast(input_x, mstype.float32)
return exp(input_x) return exp(input_x)
@ -51,8 +54,10 @@ def log_generic(input_x):
dtype = P.DType() dtype = P.DType()
shape = P.Shape() shape = P.Shape()
select = P.Select() select = P.Select()
checktype = P.IsSubClass()
input_x = cast(input_x, mstype.float32) if not checktype(dtype(input_x), mstype.float_):
input_x = cast(input_x, mstype.float32)
nan = fill(dtype(input_x), shape(input_x), np.nan) nan = fill(dtype(input_x), shape(input_x), np.nan)
inf = fill(dtype(input_x), shape(input_x), np.inf) inf = fill(dtype(input_x), shape(input_x), np.inf)
neg_x = less(input_x, 0.0) neg_x = less(input_x, 0.0)

View File

@ -222,7 +222,7 @@ class Bernoulli(Distribution):
pmf(k) = probs0 if k = 0; pmf(k) = probs0 if k = 0;
""" """
value = self._check_value(value, 'value') value = self._check_value(value, 'value')
value = self.cast(value, mstype.float32) value = self.cast(value, self.parameter_type)
probs1 = self._check_param_type(probs1) probs1 = self._check_param_type(probs1)
probs0 = 1.0 - probs1 probs0 = 1.0 - probs1
return self.log(probs1) * value + self.log(probs0) * (1.0 - value) return self.log(probs1) * value + self.log(probs0) * (1.0 - value)
@ -241,7 +241,7 @@ class Bernoulli(Distribution):
cdf(k) = 1 if k >=1; cdf(k) = 1 if k >=1;
""" """
value = self._check_value(value, 'value') value = self._check_value(value, 'value')
value = self.cast(value, mstype.float32) value = self.cast(value, self.parameter_type)
value = self.floor(value) value = self.floor(value)
probs1 = self._check_param_type(probs1) probs1 = self._check_param_type(probs1)
prob_type = self.dtypeop(probs1) prob_type = self.dtypeop(probs1)

View File

@ -225,7 +225,7 @@ class Geometric(Distribution):
pmf(k) = 0 if k < 0. pmf(k) = 0 if k < 0.
""" """
value = self._check_value(value, 'value') value = self._check_value(value, 'value')
value = self.cast(value, mstype.float32) value = self.cast(value, self.parameter_type)
value = self.floor(value) value = self.floor(value)
probs1 = self._check_param_type(probs1) probs1 = self._check_param_type(probs1)
pmf = self.exp(self.log(1.0 - probs1) * value + self.log(probs1)) pmf = self.exp(self.log(1.0 - probs1) * value + self.log(probs1))
@ -247,7 +247,7 @@ class Geometric(Distribution):
""" """
value = self._check_value(value, 'value') value = self._check_value(value, 'value')
value = self.cast(value, mstype.float32) value = self.cast(value, self.parameter_type)
value = self.floor(value) value = self.floor(value)
probs1 = self._check_param_type(probs1) probs1 = self._check_param_type(probs1)
probs0 = 1.0 - probs1 probs0 = 1.0 - probs1