!49823 fix get_seed graph mode bug

Merge pull request !49823 from YingtongHu/master
This commit is contained in:
i-robot 2023-03-08 01:17:53 +00:00 committed by Gitee
commit f1bb1d9130
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 5 additions and 4 deletions

View File

@ -25,10 +25,11 @@ from mindspore import context
from mindspore import ops
from mindspore.common import Tensor
from mindspore.common import dtype as mstype
from mindspore.common.seed import get_seed, _get_graph_seed
from mindspore.common.seed import get_seed
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops.primitive import constexpr, _primexpr
from mindspore.ops.function.random_func import _get_seed
from mindspore.nn.layer.basic import tril as nn_tril
from mindspore.nn.layer.basic import triu as nn_triu
from mindspore._c_expression import Tensor as Tensor_
@ -455,7 +456,7 @@ def randn(*shape, dtype=mstype.float32):
size = _generate_shapes(shape)
seed = get_seed()
if seed is not None:
seed1, seed2 = _get_graph_seed(seed, "StandardNormal")
seed1, seed2 = _get_seed(seed, "StandardNormal")
stdnormal = P.StandardNormal(seed=seed1, seed2=seed2)
else:
stdnormal = P.StandardNormal()
@ -496,7 +497,7 @@ def rand(*shape, dtype=mstype.float32):
size = _generate_shapes(shape)
seed = get_seed()
if seed is not None:
seed1, seed2 = _get_graph_seed(seed, "UniformReal")
seed1, seed2 = _get_seed(seed, "UniformReal")
uniformreal = P.UniformReal(seed=seed1, seed2=seed2)
else:
uniformreal = P.UniformReal()
@ -566,7 +567,7 @@ def randint(minval, maxval=None, shape=None, dtype=mstype.int32):
shape = _check_shape(shape)
seed = get_seed()
if seed is not None:
seed1, seed2 = _get_graph_seed(seed, "UniformInt")
seed1, seed2 = _get_seed(seed, "UniformInt")
uniformint = P.UniformInt(seed=seed1, seed2=seed2)
else:
uniformint = P.UniformInt()