forked from mindspore-Ecosystem/mindspore
fix multinomial
This commit is contained in:
parent
5adba834d0
commit
216ef0e144
|
@ -20,7 +20,6 @@ from .. import functional as F
|
|||
from ..primitive import constexpr
|
||||
from .multitype_ops import _constexpr_utils as const_utils
|
||||
from ...common import dtype as mstype
|
||||
from ...common.tensor import Tensor
|
||||
from ..._checkparam import Validator as validator
|
||||
from ..._checkparam import check_int_positive
|
||||
from ..._checkparam import Rel
|
||||
|
@ -134,9 +133,7 @@ def multinomial(inputs, num_sample=None, replacement=True, seed=0):
|
|||
n_dist = 1
|
||||
if len(shape(inputs)) > 1:
|
||||
n_dist = shape(inputs)[-2]
|
||||
a = Tensor(0.0, mstype.float32)
|
||||
b = Tensor(1.0, mstype.float32)
|
||||
random_uniform = P.UniformReal(seed=seed)((n_dist * num_sample,), a, b)
|
||||
random_uniform = P.UniformReal(seed=seed)((n_dist * num_sample,))
|
||||
if n_dist != 1:
|
||||
random_uniform = reshape(random_uniform, (n_dist, num_sample))
|
||||
vals = P.RealDiv()(P.Log()(random_uniform), inputs + 1e-6)
|
||||
|
|
|
@ -30,10 +30,6 @@ def test_multinomial():
|
|||
out0 = C.multinomial(x0, 1, True)
|
||||
out1 = C.multinomial(x0, 2, True)
|
||||
out2 = C.multinomial(x1, 6, True)
|
||||
out3 = C.multinomial(x0, 1, False)
|
||||
out4 = C.multinomial(x0, 2, False)
|
||||
assert out0.asnumpy().shape == (1,)
|
||||
assert out1.asnumpy().shape == (2,)
|
||||
assert out2.asnumpy().shape == (2, 6)
|
||||
assert out3.asnumpy().shape == (1,)
|
||||
assert out4.asnumpy().shape == (2,)
|
||||
|
|
Loading…
Reference in New Issue