fix numpy rand, randn, randint

This commit is contained in:
Yingtong Hu 2023-03-01 10:40:20 +08:00
parent 011210061e
commit eac88ebd67
2 changed files with 22 additions and 7 deletions

View File

@ -25,7 +25,7 @@ from mindspore import context
from mindspore import ops
from mindspore.common import Tensor
from mindspore.common import dtype as mstype
from mindspore.common.seed import get_seed
from mindspore.common.seed import get_seed, _get_graph_seed
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops.primitive import constexpr, _primexpr
@ -455,7 +455,8 @@ def randn(*shape, dtype=mstype.float32):
size = _generate_shapes(shape)
seed = get_seed()
if seed is not None:
stdnormal = P.StandardNormal(seed=seed)
seed1, seed2 = _get_graph_seed(seed, "StandardNormal")
stdnormal = P.StandardNormal(seed=seed1, seed2=seed2)
else:
stdnormal = P.StandardNormal()
return stdnormal(size).astype(dtype)
@ -495,7 +496,8 @@ def rand(*shape, dtype=mstype.float32):
size = _generate_shapes(shape)
seed = get_seed()
if seed is not None:
uniformreal = P.UniformReal(seed=seed)
seed1, seed2 = _get_graph_seed(seed, "UniformReal")
uniformreal = P.UniformReal(seed=seed1, seed2=seed2)
else:
uniformreal = P.UniformReal()
return uniformreal(size).astype(dtype)
@ -564,7 +566,8 @@ def randint(minval, maxval=None, shape=None, dtype=mstype.int32):
shape = _check_shape(shape)
seed = get_seed()
if seed is not None:
uniformint = P.UniformInt(seed=seed)
seed1, seed2 = _get_graph_seed(seed, "UniformInt")
uniformint = P.UniformInt(seed=seed1, seed2=seed2)
else:
uniformint = P.UniformInt()
t_min = _type_convert(Tensor, minval).astype(dtype)

View File

@ -903,7 +903,11 @@ def test_randn():
set_seed(1)
t1 = mnp.randn(1, 2, 3)
t2 = mnp.randn(1, 2, 3)
assert (t1.asnumpy() == t2.asnumpy()).all()
assert onp.array_equal(t1.asnumpy(), t2.asnumpy()) is False
set_seed(1)
t3 = mnp.randn(1, 2, 3)
assert (t1.asnumpy() == t3.asnumpy()).all()
with pytest.raises(ValueError):
mnp.randn(dtype="int32")
@ -931,7 +935,11 @@ def test_rand():
set_seed(1)
t1 = mnp.rand(1, 2, 3)
t2 = mnp.rand(1, 2, 3)
assert (t1.asnumpy() == t2.asnumpy()).all()
assert onp.array_equal(t1.asnumpy(), t2.asnumpy()) is False
set_seed(1)
t3 = mnp.rand(1, 2, 3)
assert (t1.asnumpy() == t3.asnumpy()).all()
with pytest.raises(ValueError):
mnp.rand(dtype="int32")
@ -958,7 +966,11 @@ def test_randint():
set_seed(1)
t1 = mnp.randint(1, 5, 3)
t2 = mnp.randint(1, 5, 3)
assert (t1.asnumpy() == t2.asnumpy()).all()
assert onp.array_equal(t1.asnumpy(), t2.asnumpy()) is False
set_seed(1)
t3 = mnp.randint(1, 5, 3)
assert (t1.asnumpy() == t3.asnumpy()).all()
with pytest.raises(TypeError):
mnp.randint(1.2)