forked from mindspore-Ecosystem/mindspore
commit
1406ac43c3
|
@ -20,7 +20,6 @@ from .. import functional as F
|
||||||
from ..primitive import constexpr
|
from ..primitive import constexpr
|
||||||
from .multitype_ops import _constexpr_utils as const_utils
|
from .multitype_ops import _constexpr_utils as const_utils
|
||||||
from ...common import dtype as mstype
|
from ...common import dtype as mstype
|
||||||
from ...common.tensor import Tensor
|
|
||||||
from ..._checkparam import Validator as validator
|
from ..._checkparam import Validator as validator
|
||||||
from ..._checkparam import check_int_positive
|
from ..._checkparam import check_int_positive
|
||||||
from ..._checkparam import Rel
|
from ..._checkparam import Rel
|
||||||
|
@ -134,9 +133,7 @@ def multinomial(inputs, num_sample=None, replacement=True, seed=0):
|
||||||
n_dist = 1
|
n_dist = 1
|
||||||
if len(shape(inputs)) > 1:
|
if len(shape(inputs)) > 1:
|
||||||
n_dist = shape(inputs)[-2]
|
n_dist = shape(inputs)[-2]
|
||||||
a = Tensor(0.0, mstype.float32)
|
random_uniform = P.UniformReal(seed=seed)((n_dist * num_sample,))
|
||||||
b = Tensor(1.0, mstype.float32)
|
|
||||||
random_uniform = P.UniformReal(seed=seed)((n_dist * num_sample,), a, b)
|
|
||||||
if n_dist != 1:
|
if n_dist != 1:
|
||||||
random_uniform = reshape(random_uniform, (n_dist, num_sample))
|
random_uniform = reshape(random_uniform, (n_dist, num_sample))
|
||||||
vals = P.RealDiv()(P.Log()(random_uniform), inputs + 1e-6)
|
vals = P.RealDiv()(P.Log()(random_uniform), inputs + 1e-6)
|
||||||
|
|
|
@ -30,10 +30,6 @@ def test_multinomial():
|
||||||
out0 = C.multinomial(x0, 1, True)
|
out0 = C.multinomial(x0, 1, True)
|
||||||
out1 = C.multinomial(x0, 2, True)
|
out1 = C.multinomial(x0, 2, True)
|
||||||
out2 = C.multinomial(x1, 6, True)
|
out2 = C.multinomial(x1, 6, True)
|
||||||
out3 = C.multinomial(x0, 1, False)
|
|
||||||
out4 = C.multinomial(x0, 2, False)
|
|
||||||
assert out0.asnumpy().shape == (1,)
|
assert out0.asnumpy().shape == (1,)
|
||||||
assert out1.asnumpy().shape == (2,)
|
assert out1.asnumpy().shape == (2,)
|
||||||
assert out2.asnumpy().shape == (2, 6)
|
assert out2.asnumpy().shape == (2, 6)
|
||||||
assert out3.asnumpy().shape == (1,)
|
|
||||||
assert out4.asnumpy().shape == (2,)
|
|
||||||
|
|
Loading…
Reference in New Issue