forked from mindspore-Ecosystem/mindspore
!4952 Fix errors in log calculation logics
Merge pull request !4952 from peixu_ren/custom_pp_ops
This commit is contained in:
commit
b366608a3f
|
@ -15,24 +15,30 @@
|
|||
"""Utitly functions to help distribution class."""
|
||||
import numpy as np
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common import dtype as mstype
|
||||
|
||||
def log_by_step(input_x):
|
||||
"""
|
||||
Log op on Ascend is calculated as log(abs(x)).
|
||||
Fix this with putting negative values as nan.
|
||||
"""
|
||||
select = P.Select()
|
||||
log = P.Log()
|
||||
less = P.Less()
|
||||
lessequal = P.LessEqual()
|
||||
fill = P.Fill()
|
||||
cast = P.Cast()
|
||||
dtype = P.DType()
|
||||
shape = P.Shape()
|
||||
select = P.Select()
|
||||
|
||||
input_x = cast(input_x, mstype.float32)
|
||||
nan = fill(dtype(input_x), shape(input_x), np.nan)
|
||||
inf = fill(dtype(input_x), shape(input_x), np.inf)
|
||||
neg_x = less(input_x, 0.0)
|
||||
nonpos_x = lessequal(input_x, 0.0)
|
||||
log_x = log(input_x)
|
||||
nan = fill(dtype(input_x), shape(input_x), np.nan)
|
||||
result = select(nonpos_x, nan, log_x)
|
||||
return result
|
||||
result = select(nonpos_x, -inf, log_x)
|
||||
return select(neg_x, nan, result)
|
||||
|
||||
def log1p_by_step(x):
|
||||
"""
|
||||
|
|
|
@ -157,30 +157,6 @@ def test_cross_entropy():
|
|||
ans = net(probs_b, probs_a)
|
||||
assert isinstance(ans, Tensor)
|
||||
|
||||
class BernoulliBasics(nn.Cell):
|
||||
"""
|
||||
Test class: basic mean/sd/var/mode/entropy function.
|
||||
"""
|
||||
def __init__(self):
|
||||
super(BernoulliBasics, self).__init__()
|
||||
self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32)
|
||||
|
||||
def construct(self):
|
||||
mean = self.b.mean()
|
||||
sd = self.b.sd()
|
||||
var = self.b.var()
|
||||
mode = self.b.mode()
|
||||
entropy = self.b.entropy()
|
||||
return mean + sd + var + mode + entropy
|
||||
|
||||
def test_bascis():
|
||||
"""
|
||||
Test mean/sd/var/mode/entropy functionality of Bernoulli distribution.
|
||||
"""
|
||||
net = BernoulliBasics()
|
||||
ans = net()
|
||||
assert isinstance(ans, Tensor)
|
||||
|
||||
class BernoulliConstruct(nn.Cell):
|
||||
"""
|
||||
Bernoulli distribution: going through construct.
|
||||
|
@ -205,3 +181,103 @@ def test_bernoulli_construct():
|
|||
probs = Tensor([0.5], dtype=dtype.float32)
|
||||
ans = net(value, probs)
|
||||
assert isinstance(ans, Tensor)
|
||||
|
||||
class BernoulliMean(nn.Cell):
|
||||
"""
|
||||
Test class: basic mean/sd/var/mode/entropy function.
|
||||
"""
|
||||
def __init__(self):
|
||||
super(BernoulliMean, self).__init__()
|
||||
self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32)
|
||||
|
||||
def construct(self):
|
||||
mean = self.b.mean()
|
||||
return mean
|
||||
|
||||
def test_mean():
|
||||
"""
|
||||
Test mean/sd/var/mode/entropy functionality of Bernoulli distribution.
|
||||
"""
|
||||
net = BernoulliMean()
|
||||
ans = net()
|
||||
assert isinstance(ans, Tensor)
|
||||
|
||||
class BernoulliSd(nn.Cell):
|
||||
"""
|
||||
Test class: basic mean/sd/var/mode/entropy function.
|
||||
"""
|
||||
def __init__(self):
|
||||
super(BernoulliSd, self).__init__()
|
||||
self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32)
|
||||
|
||||
def construct(self):
|
||||
sd = self.b.sd()
|
||||
return sd
|
||||
|
||||
def test_sd():
|
||||
"""
|
||||
Test mean/sd/var/mode/entropy functionality of Bernoulli distribution.
|
||||
"""
|
||||
net = BernoulliSd()
|
||||
ans = net()
|
||||
assert isinstance(ans, Tensor)
|
||||
|
||||
class BernoulliVar(nn.Cell):
|
||||
"""
|
||||
Test class: basic mean/sd/var/mode/entropy function.
|
||||
"""
|
||||
def __init__(self):
|
||||
super(BernoulliVar, self).__init__()
|
||||
self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32)
|
||||
|
||||
def construct(self):
|
||||
var = self.b.var()
|
||||
return var
|
||||
|
||||
def test_var():
|
||||
"""
|
||||
Test mean/sd/var/mode/entropy functionality of Bernoulli distribution.
|
||||
"""
|
||||
net = BernoulliVar()
|
||||
ans = net()
|
||||
assert isinstance(ans, Tensor)
|
||||
|
||||
class BernoulliMode(nn.Cell):
|
||||
"""
|
||||
Test class: basic mean/sd/var/mode/entropy function.
|
||||
"""
|
||||
def __init__(self):
|
||||
super(BernoulliMode, self).__init__()
|
||||
self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32)
|
||||
|
||||
def construct(self):
|
||||
mode = self.b.mode()
|
||||
return mode
|
||||
|
||||
def test_mode():
|
||||
"""
|
||||
Test mean/sd/var/mode/entropy functionality of Bernoulli distribution.
|
||||
"""
|
||||
net = BernoulliMode()
|
||||
ans = net()
|
||||
assert isinstance(ans, Tensor)
|
||||
|
||||
class BernoulliEntropy(nn.Cell):
|
||||
"""
|
||||
Test class: basic mean/sd/var/mode/entropy function.
|
||||
"""
|
||||
def __init__(self):
|
||||
super(BernoulliEntropy, self).__init__()
|
||||
self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32)
|
||||
|
||||
def construct(self):
|
||||
entropy = self.b.entropy()
|
||||
return entropy
|
||||
|
||||
def test_entropy():
|
||||
"""
|
||||
Test mean/sd/var/mode/entropy functionality of Bernoulli distribution.
|
||||
"""
|
||||
net = BernoulliEntropy()
|
||||
ans = net()
|
||||
assert isinstance(ans, Tensor)
|
||||
|
|
Loading…
Reference in New Issue