fix multinomial

This commit is contained in:
baihuawei 2020-08-12 14:54:02 +08:00
parent 5adba834d0
commit 216ef0e144
2 changed files with 1 additions and 8 deletions

View File

@ -20,7 +20,6 @@ from .. import functional as F
from ..primitive import constexpr
from .multitype_ops import _constexpr_utils as const_utils
from ...common import dtype as mstype
from ...common.tensor import Tensor
from ..._checkparam import Validator as validator
from ..._checkparam import check_int_positive
from ..._checkparam import Rel
@ -134,9 +133,7 @@ def multinomial(inputs, num_sample=None, replacement=True, seed=0):
n_dist = 1
if len(shape(inputs)) > 1:
n_dist = shape(inputs)[-2]
a = Tensor(0.0, mstype.float32)
b = Tensor(1.0, mstype.float32)
random_uniform = P.UniformReal(seed=seed)((n_dist * num_sample,), a, b)
random_uniform = P.UniformReal(seed=seed)((n_dist * num_sample,))
if n_dist != 1:
random_uniform = reshape(random_uniform, (n_dist, num_sample))
vals = P.RealDiv()(P.Log()(random_uniform), inputs + 1e-6)

View File

@ -30,10 +30,6 @@ def test_multinomial():
out0 = C.multinomial(x0, 1, True)
out1 = C.multinomial(x0, 2, True)
out2 = C.multinomial(x1, 6, True)
out3 = C.multinomial(x0, 1, False)
out4 = C.multinomial(x0, 2, False)
assert out0.asnumpy().shape == (1,)
assert out1.asnumpy().shape == (2,)
assert out2.asnumpy().shape == (2, 6)
assert out3.asnumpy().shape == (1,)
assert out4.asnumpy().shape == (2,)