forked from mindspore-Ecosystem/mindspore
fix numpy rand, randn, randint
This commit is contained in:
parent
011210061e
commit
eac88ebd67
|
@ -25,7 +25,7 @@ from mindspore import context
|
|||
from mindspore import ops
|
||||
from mindspore.common import Tensor
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common.seed import get_seed
|
||||
from mindspore.common.seed import get_seed, _get_graph_seed
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops.primitive import constexpr, _primexpr
|
||||
|
@ -455,7 +455,8 @@ def randn(*shape, dtype=mstype.float32):
|
|||
size = _generate_shapes(shape)
|
||||
seed = get_seed()
|
||||
if seed is not None:
|
||||
stdnormal = P.StandardNormal(seed=seed)
|
||||
seed1, seed2 = _get_graph_seed(seed, "StandardNormal")
|
||||
stdnormal = P.StandardNormal(seed=seed1, seed2=seed2)
|
||||
else:
|
||||
stdnormal = P.StandardNormal()
|
||||
return stdnormal(size).astype(dtype)
|
||||
|
@ -495,7 +496,8 @@ def rand(*shape, dtype=mstype.float32):
|
|||
size = _generate_shapes(shape)
|
||||
seed = get_seed()
|
||||
if seed is not None:
|
||||
uniformreal = P.UniformReal(seed=seed)
|
||||
seed1, seed2 = _get_graph_seed(seed, "UniformReal")
|
||||
uniformreal = P.UniformReal(seed=seed1, seed2=seed2)
|
||||
else:
|
||||
uniformreal = P.UniformReal()
|
||||
return uniformreal(size).astype(dtype)
|
||||
|
@ -564,7 +566,8 @@ def randint(minval, maxval=None, shape=None, dtype=mstype.int32):
|
|||
shape = _check_shape(shape)
|
||||
seed = get_seed()
|
||||
if seed is not None:
|
||||
uniformint = P.UniformInt(seed=seed)
|
||||
seed1, seed2 = _get_graph_seed(seed, "UniformInt")
|
||||
uniformint = P.UniformInt(seed=seed1, seed2=seed2)
|
||||
else:
|
||||
uniformint = P.UniformInt()
|
||||
t_min = _type_convert(Tensor, minval).astype(dtype)
|
||||
|
|
|
@ -903,7 +903,11 @@ def test_randn():
|
|||
set_seed(1)
|
||||
t1 = mnp.randn(1, 2, 3)
|
||||
t2 = mnp.randn(1, 2, 3)
|
||||
assert (t1.asnumpy() == t2.asnumpy()).all()
|
||||
assert onp.array_equal(t1.asnumpy(), t2.asnumpy()) is False
|
||||
|
||||
set_seed(1)
|
||||
t3 = mnp.randn(1, 2, 3)
|
||||
assert (t1.asnumpy() == t3.asnumpy()).all()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
mnp.randn(dtype="int32")
|
||||
|
@ -931,7 +935,11 @@ def test_rand():
|
|||
set_seed(1)
|
||||
t1 = mnp.rand(1, 2, 3)
|
||||
t2 = mnp.rand(1, 2, 3)
|
||||
assert (t1.asnumpy() == t2.asnumpy()).all()
|
||||
assert onp.array_equal(t1.asnumpy(), t2.asnumpy()) is False
|
||||
|
||||
set_seed(1)
|
||||
t3 = mnp.rand(1, 2, 3)
|
||||
assert (t1.asnumpy() == t3.asnumpy()).all()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
mnp.rand(dtype="int32")
|
||||
|
@ -958,7 +966,11 @@ def test_randint():
|
|||
set_seed(1)
|
||||
t1 = mnp.randint(1, 5, 3)
|
||||
t2 = mnp.randint(1, 5, 3)
|
||||
assert (t1.asnumpy() == t2.asnumpy()).all()
|
||||
assert onp.array_equal(t1.asnumpy(), t2.asnumpy()) is False
|
||||
|
||||
set_seed(1)
|
||||
t3 = mnp.randint(1, 5, 3)
|
||||
assert (t1.asnumpy() == t3.asnumpy()).all()
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
mnp.randint(1.2)
|
||||
|
|
Loading…
Reference in New Issue