forked from mindspore-Ecosystem/mindspore
!5092 Add erf and erfc as generic functions for all the backend
Merge pull request !5092 from peixu_ren/custom_bijector
This commit is contained in:
commit
bb87ed6088
|
@ -17,14 +17,14 @@ from mindspore.ops import operations as P
|
|||
from mindspore._checkparam import Validator as validator
|
||||
from mindspore._checkparam import Rel
|
||||
from ..distribution._utils.utils import CheckTensor
|
||||
from ..distribution._utils.custom_ops import exp_by_step, expm1_by_step, log_by_step, log1p_by_step
|
||||
from ..distribution._utils.custom_ops import exp_generic, expm1_generic, log_generic, log1p_generic
|
||||
from .bijector import Bijector
|
||||
|
||||
|
||||
class PowerTransform(Bijector):
|
||||
r"""
|
||||
Power Bijector.
|
||||
This Bijector performs the operation: Y = g(X) = (1 + X * c)^(1 / c), X >= -1 / c, where c is power.
|
||||
This Bijector performs the operation: Y = g(X) = (1 + X * c)^(1 / c), X >= -1 / c, where c >= 0 is the power.
|
||||
|
||||
The power transform maps inputs from `[-1/c, inf]` to `[0, inf]`.
|
||||
|
||||
|
@ -61,10 +61,10 @@ class PowerTransform(Bijector):
|
|||
validator.check_number("power", power, 0, Rel.GE, self.name)
|
||||
self._power = power
|
||||
self.pow = P.Pow()
|
||||
self.exp = exp_by_step
|
||||
self.expm1 = expm1_by_step
|
||||
self.log = log_by_step
|
||||
self.log1p = log1p_by_step
|
||||
self.exp = exp_generic
|
||||
self.expm1 = expm1_generic
|
||||
self.log = log_generic
|
||||
self.log1p = log1p_generic
|
||||
|
||||
self.checktensor = CheckTensor()
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
from mindspore.ops import operations as P
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from ..distribution._utils.utils import cast_to_tensor, CheckTensor
|
||||
from ..distribution._utils.custom_ops import log_by_step
|
||||
from ..distribution._utils.custom_ops import log_generic
|
||||
from .bijector import Bijector
|
||||
|
||||
|
||||
|
@ -69,7 +69,7 @@ class ScalarAffine(Bijector):
|
|||
param=param)
|
||||
|
||||
self.abs = P.Abs()
|
||||
self.log = log_by_step
|
||||
self.log = log_generic
|
||||
|
||||
self.checktensor = CheckTensor()
|
||||
|
||||
|
|
|
@ -19,7 +19,7 @@ from mindspore.common import dtype as mstype
|
|||
from mindspore.nn.layer.activation import LogSigmoid
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from ..distribution._utils.utils import cast_to_tensor, CheckTensor
|
||||
from ..distribution._utils.custom_ops import exp_by_step, expm1_by_step, log_by_step
|
||||
from ..distribution._utils.custom_ops import exp_generic, expm1_generic, log_generic
|
||||
from .bijector import Bijector
|
||||
|
||||
|
||||
|
@ -61,9 +61,9 @@ class Softplus(Bijector):
|
|||
super(Softplus, self).__init__(name=name, param=param)
|
||||
self._sharpness = cast_to_tensor(sharpness)
|
||||
|
||||
self.exp = exp_by_step
|
||||
self.log = log_by_step
|
||||
self.expm1 = expm1_by_step
|
||||
self.exp = exp_generic
|
||||
self.log = log_generic
|
||||
self.expm1 = expm1_generic
|
||||
self.abs = P.Abs()
|
||||
self.fill = P.Fill()
|
||||
self.greater = P.Greater()
|
||||
|
|
|
@ -28,8 +28,10 @@ __all__ = [
|
|||
'check_scalar_from_param',
|
||||
'check_prob',
|
||||
'check_type',
|
||||
'exp_by_step',
|
||||
'expm1_by_step',
|
||||
'log_by_step',
|
||||
'log1p_by_step',
|
||||
'exp_generic',
|
||||
'expm1_generic',
|
||||
'log_generic',
|
||||
'log1p_generic',
|
||||
'erf_generic',
|
||||
'erfc_generic',
|
||||
]
|
||||
|
|
|
@ -17,8 +17,7 @@ import numpy as np
|
|||
from mindspore.ops import operations as P
|
||||
from mindspore.common import dtype as mstype
|
||||
|
||||
|
||||
def exp_by_step(input_x):
|
||||
def exp_generic(input_x):
|
||||
"""
|
||||
Log op on Ascend doesn't supprot int types.
|
||||
Fix this with casting the type.
|
||||
|
@ -30,14 +29,14 @@ def exp_by_step(input_x):
|
|||
return exp(input_x)
|
||||
|
||||
|
||||
def expm1_by_step(input_x):
|
||||
def expm1_generic(input_x):
|
||||
"""
|
||||
Expm1 ops under GPU context.
|
||||
"""
|
||||
return exp_by_step(input_x) - 1.0
|
||||
return exp_generic(input_x) - 1.0
|
||||
|
||||
|
||||
def log_by_step(input_x):
|
||||
def log_generic(input_x):
|
||||
"""
|
||||
Log op on Ascend is calculated as log(abs(x)).
|
||||
Fix this with putting negative values as nan.
|
||||
|
@ -63,8 +62,166 @@ def log_by_step(input_x):
|
|||
return select(neg_x, nan, result)
|
||||
|
||||
|
||||
def log1p_by_step(x):
|
||||
def log1p_generic(x):
|
||||
"""
|
||||
Log1p ops on GPU device or when device_target == GPU.
|
||||
"""
|
||||
return log_by_step(x + 1.0)
|
||||
return log_generic(x + 1.0)
|
||||
|
||||
def _evaluate_polynomial(x, coefficients):
|
||||
poly = 0
|
||||
for co in coefficients:
|
||||
poly = poly * x + co
|
||||
return poly
|
||||
|
||||
def erf_f32_generic(x):
|
||||
"""
|
||||
Calculate erf for dtype of f32
|
||||
"""
|
||||
k_erf_tcoefficient = [+7.853861353153693e-5,
|
||||
-8.010193625184903e-4,
|
||||
+5.188327685732524e-3,
|
||||
-2.685381193529856e-2,
|
||||
+1.128358514861418e-1,
|
||||
-3.761262582423300e-1,
|
||||
+1.128379165726710e+0]
|
||||
|
||||
poly = _evaluate_polynomial(x * x, k_erf_tcoefficient)
|
||||
return x * poly
|
||||
|
||||
def erf_f64_generic(x):
|
||||
"""
|
||||
Calculate erf for dtype of f64
|
||||
"""
|
||||
k_erf_tcoefficient = [9.60497373987051638749e0,
|
||||
9.00260197203842689217e1,
|
||||
2.23200534594684319226e3,
|
||||
7.00332514112805075473e3,
|
||||
5.55923013010394962768e4]
|
||||
k_erf_ucoefficient = [1.00000000000000000000e0,
|
||||
3.35617141647503099647e1,
|
||||
5.21357949780152679795e2,
|
||||
4.59432382970980127987e3,
|
||||
2.26290000613890934246e4,
|
||||
4.92673942608635921086e4]
|
||||
z = x * x
|
||||
poly1 = _evaluate_polynomial(z, k_erf_tcoefficient)
|
||||
poly2 = _evaluate_polynomial(z, k_erf_ucoefficient)
|
||||
return x * poly1 / poly2
|
||||
|
||||
|
||||
def erfc_f32_generic(x):
|
||||
"""
|
||||
Calculate erfc for dtype of f32
|
||||
"""
|
||||
k_maxlog = 88.72283905206835
|
||||
k_erfc_pcoefficient = [+2.326819970068386e-2,
|
||||
-1.387039388740657e-1,
|
||||
+3.687424674597105e-1,
|
||||
-5.824733027278666e-1,
|
||||
+6.210004621745983e-1,
|
||||
-4.944515323274145e-1,
|
||||
+3.404879937665872e-1,
|
||||
-2.741127028184656e-1,
|
||||
+5.638259427386472e-1]
|
||||
k_erfc_rcoefficient = [-1.047766399936249e+1,
|
||||
+1.297719955372516e+1,
|
||||
-7.495518717768503e+0,
|
||||
+2.921019019210786e+0,
|
||||
-1.015265279202700e+0,
|
||||
+4.218463358204948e-1,
|
||||
-2.820767439740514e-1,
|
||||
+5.641895067754075e-1]
|
||||
abs_cal = P.Abs()
|
||||
select = P.Select()
|
||||
less = P.Less()
|
||||
fill = P.Fill()
|
||||
dtype = P.DType()
|
||||
shape = P.Shape()
|
||||
|
||||
abs_x = abs_cal(x)
|
||||
z = exp_generic(-x * x)
|
||||
q = 1 / abs_x
|
||||
y = q * q
|
||||
poly1 = _evaluate_polynomial(y, k_erfc_pcoefficient)
|
||||
poly2 = _evaluate_polynomial(y, k_erfc_rcoefficient)
|
||||
p = select(less(abs_x, 2.0), poly1, poly2)
|
||||
y = z * q * p
|
||||
zeros = fill(dtype(x), shape(x), 0)
|
||||
y_clamp = select(less(z, -k_maxlog), zeros, y)
|
||||
return select(less(x, 0), 2.0 - y_clamp, y_clamp)
|
||||
|
||||
|
||||
def erfc_f64_generic(x):
|
||||
"""
|
||||
Calculate erfc for dtype of f64
|
||||
"""
|
||||
k_maxlog = 7.09782712893383996843e2
|
||||
k_erfc_pcoefficient = [2.46196981473530512524e-10,
|
||||
5.64189564831068821977e-1,
|
||||
7.46321056442269912687e0,
|
||||
4.86371970985681366614e1,
|
||||
1.96520832956077098242e2,
|
||||
5.26445194995477358631e2,
|
||||
9.34528527171957607540e2,
|
||||
1.02755188689515710272e3,
|
||||
5.57535335369399327526e2]
|
||||
k_erfc_qcoefficient = [1.00000000000000000000e0,
|
||||
1.32281951154744992508e1,
|
||||
8.67072140885989742329e1,
|
||||
3.54937778887819891062e2,
|
||||
9.75708501743205489753e2,
|
||||
1.82390916687909736289e3,
|
||||
2.24633760818710981792e3,
|
||||
1.65666309194161350182e3,
|
||||
5.57535340817727675546e2]
|
||||
k_erfc_rcoefficient = [5.64189583547755073984e-1,
|
||||
1.27536670759978104416e0,
|
||||
5.01905042251180477414e0,
|
||||
6.16021097993053585195e0,
|
||||
7.40974269950448939160e0,
|
||||
2.97886665372100240670e0]
|
||||
k_erfc_scoefficient = [1.00000000000000000000e0,
|
||||
2.26052863220117276590e0,
|
||||
9.39603524938001434673e0,
|
||||
1.20489539808096656605e1,
|
||||
1.70814450747565897222e1,
|
||||
9.60896809063285878198e0,
|
||||
3.36907645100081516050e02]
|
||||
abs_cal = P.Abs()
|
||||
select = P.Select()
|
||||
less = P.Less()
|
||||
fill = P.Fill()
|
||||
dtype = P.DType()
|
||||
shape = P.Shape()
|
||||
|
||||
abs_x = abs_cal(x)
|
||||
z = -x * x
|
||||
exp_z = exp_generic(z)
|
||||
|
||||
temp1 = exp_z * _evaluate_polynomial(abs_x, k_erfc_pcoefficient) / _evaluate_polynomial(abs_x, k_erfc_qcoefficient)
|
||||
temp2 = exp_z * _evaluate_polynomial(abs_x, k_erfc_rcoefficient) / _evaluate_polynomial(abs_x, k_erfc_scoefficient)
|
||||
y = select(less(abs_x, 8.0), temp1, temp2)
|
||||
zeros = fill(dtype(x), shape(x), 0)
|
||||
y_clamp = select(less(z, k_maxlog), zeros, y)
|
||||
|
||||
poly2 = _evaluate_polynomial(y, k_erfc_rcoefficient)
|
||||
p = select(less(abs_x, 2.0), poly1, poly2)
|
||||
y = z * q * p
|
||||
zeros = fill(dtype(x), shape(x), 0)
|
||||
y_clamp = select(less(z, -k_maxlog), zeros, y)
|
||||
return select(less(x, 0), 2.0 - y_clamp, y_clamp)
|
||||
|
||||
def erfc_generic(x):
|
||||
select = P.Select()
|
||||
greater = P.Greater()
|
||||
abs_cal = P.Abs()
|
||||
|
||||
return select(greater(abs_cal(x), 1), erfc_f32_generic(x), 1 - erf_f32_generic(x))
|
||||
|
||||
def erf_generic(x):
|
||||
select = P.Select()
|
||||
less = P.Less()
|
||||
abs_cal = P.Abs()
|
||||
|
||||
return select(less(abs_cal(x), 1), erf_f32_generic(x), 1 - erfc_f32_generic(x))
|
||||
|
|
|
@ -18,7 +18,7 @@ from mindspore.ops import operations as P
|
|||
from mindspore.ops import composite as C
|
||||
from .distribution import Distribution
|
||||
from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name, raise_none_error
|
||||
from ._utils.custom_ops import exp_by_step, log_by_step
|
||||
from ._utils.custom_ops import exp_generic, log_generic, erf_generic
|
||||
|
||||
|
||||
class Bernoulli(Distribution):
|
||||
|
@ -109,13 +109,13 @@ class Bernoulli(Distribution):
|
|||
self._probs = probs
|
||||
|
||||
# ops needed for the class
|
||||
self.exp = exp_by_step
|
||||
self.log = log_by_step
|
||||
self.exp = exp_generic
|
||||
self.log = log_generic
|
||||
self.erf = erf_generic
|
||||
self.squeeze = P.Squeeze(0)
|
||||
self.cast = P.Cast()
|
||||
self.const = P.ScalarToArray()
|
||||
self.dtypeop = P.DType()
|
||||
self.erf = P.Erf()
|
||||
self.floor = P.Floor()
|
||||
self.fill = P.Fill()
|
||||
self.less = P.Less()
|
||||
|
|
|
@ -20,7 +20,7 @@ from mindspore.common import dtype as mstype
|
|||
from .distribution import Distribution
|
||||
from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name,\
|
||||
raise_none_error
|
||||
from ._utils.custom_ops import exp_by_step, log_by_step
|
||||
from ._utils.custom_ops import exp_generic, log_generic
|
||||
|
||||
class Exponential(Distribution):
|
||||
"""
|
||||
|
@ -112,8 +112,8 @@ class Exponential(Distribution):
|
|||
self.minval = np.finfo(np.float).tiny
|
||||
|
||||
# ops needed for the class
|
||||
self.exp = exp_by_step
|
||||
self.log = log_by_step
|
||||
self.exp = exp_generic
|
||||
self.log = log_generic
|
||||
self.squeeze = P.Squeeze(0)
|
||||
self.cast = P.Cast()
|
||||
self.const = P.ScalarToArray()
|
||||
|
|
|
@ -20,7 +20,7 @@ from mindspore.common import dtype as mstype
|
|||
from .distribution import Distribution
|
||||
from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name,\
|
||||
raise_none_error
|
||||
from ._utils.custom_ops import exp_by_step, log_by_step
|
||||
from ._utils.custom_ops import exp_generic, log_generic
|
||||
|
||||
|
||||
class Geometric(Distribution):
|
||||
|
@ -114,8 +114,8 @@ class Geometric(Distribution):
|
|||
self.minval = np.finfo(np.float).tiny
|
||||
|
||||
# ops needed for the class
|
||||
self.exp = exp_by_step
|
||||
self.log = log_by_step
|
||||
self.exp = exp_generic
|
||||
self.log = log_generic
|
||||
self.squeeze = P.Squeeze(0)
|
||||
self.cast = P.Cast()
|
||||
self.const = P.ScalarToArray()
|
||||
|
|
|
@ -20,7 +20,7 @@ from mindspore.common import dtype as mstype
|
|||
from .distribution import Distribution
|
||||
from ._utils.utils import convert_to_batch, check_greater_zero, check_type, check_distribution_name,\
|
||||
raise_none_error
|
||||
from ._utils.custom_ops import exp_by_step, expm1_by_step, log_by_step
|
||||
from ._utils.custom_ops import exp_generic, expm1_generic, log_generic, erf_generic
|
||||
|
||||
class Normal(Distribution):
|
||||
"""
|
||||
|
@ -114,13 +114,13 @@ class Normal(Distribution):
|
|||
self._sd_value = sd
|
||||
|
||||
#ops needed for the class
|
||||
self.exp = exp_by_step
|
||||
self.expm1 = expm1_by_step
|
||||
self.log = log_by_step
|
||||
self.exp = exp_generic
|
||||
self.expm1 = expm1_generic
|
||||
self.log = log_generic
|
||||
self.erf = erf_generic
|
||||
self.squeeze = P.Squeeze(0)
|
||||
self.cast = P.Cast()
|
||||
self.const = P.ScalarToArray()
|
||||
self.erf = P.Erf()
|
||||
self.fill = P.Fill()
|
||||
self.shape = P.Shape()
|
||||
self.sq = P.Square()
|
||||
|
|
|
@ -18,7 +18,7 @@ from mindspore.common import dtype as mstype
|
|||
import mindspore.nn as nn
|
||||
from .distribution import Distribution
|
||||
from ._utils.utils import check_type, raise_not_impl_error
|
||||
from ._utils.custom_ops import exp_by_step, log_by_step
|
||||
from ._utils.custom_ops import exp_generic, log_generic
|
||||
|
||||
class TransformedDistribution(Distribution):
|
||||
"""
|
||||
|
@ -55,8 +55,8 @@ class TransformedDistribution(Distribution):
|
|||
self._bijector = bijector
|
||||
self._distribution = distribution
|
||||
self._is_linear_transformation = bijector.is_constant_jacobian
|
||||
self.exp = exp_by_step
|
||||
self.log = log_by_step
|
||||
self.exp = exp_generic
|
||||
self.log = log_generic
|
||||
|
||||
@property
|
||||
def bijector(self):
|
||||
|
|
|
@ -19,7 +19,7 @@ from mindspore.common import dtype as mstype
|
|||
from .distribution import Distribution
|
||||
from ._utils.utils import convert_to_batch, check_greater, check_type, check_distribution_name,\
|
||||
raise_none_error
|
||||
from ._utils.custom_ops import exp_by_step, log_by_step
|
||||
from ._utils.custom_ops import exp_generic, log_generic
|
||||
|
||||
class Uniform(Distribution):
|
||||
"""
|
||||
|
@ -113,8 +113,8 @@ class Uniform(Distribution):
|
|||
self._high = high
|
||||
|
||||
# ops needed for the class
|
||||
self.exp = exp_by_step
|
||||
self.log = log_by_step
|
||||
self.exp = exp_generic
|
||||
self.log = log_generic
|
||||
self.squeeze = P.Squeeze(0)
|
||||
self.cast = P.Cast()
|
||||
self.const = P.ScalarToArray()
|
||||
|
|
Loading…
Reference in New Issue