!49823 fix get_seed graph mode bug

Merge pull request !49823 from YingtongHu/master
This commit is contained in:
i-robot 2023-03-08 01:17:53 +00:00 committed by Gitee
commit f1bb1d9130
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 5 additions and 4 deletions

View File

@ -25,10 +25,11 @@ from mindspore import context
from mindspore import ops from mindspore import ops
from mindspore.common import Tensor from mindspore.common import Tensor
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from mindspore.common.seed import get_seed, _get_graph_seed from mindspore.common.seed import get_seed
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.ops import functional as F from mindspore.ops import functional as F
from mindspore.ops.primitive import constexpr, _primexpr from mindspore.ops.primitive import constexpr, _primexpr
from mindspore.ops.function.random_func import _get_seed
from mindspore.nn.layer.basic import tril as nn_tril from mindspore.nn.layer.basic import tril as nn_tril
from mindspore.nn.layer.basic import triu as nn_triu from mindspore.nn.layer.basic import triu as nn_triu
from mindspore._c_expression import Tensor as Tensor_ from mindspore._c_expression import Tensor as Tensor_
@ -455,7 +456,7 @@ def randn(*shape, dtype=mstype.float32):
size = _generate_shapes(shape) size = _generate_shapes(shape)
seed = get_seed() seed = get_seed()
if seed is not None: if seed is not None:
seed1, seed2 = _get_graph_seed(seed, "StandardNormal") seed1, seed2 = _get_seed(seed, "StandardNormal")
stdnormal = P.StandardNormal(seed=seed1, seed2=seed2) stdnormal = P.StandardNormal(seed=seed1, seed2=seed2)
else: else:
stdnormal = P.StandardNormal() stdnormal = P.StandardNormal()
@ -496,7 +497,7 @@ def rand(*shape, dtype=mstype.float32):
size = _generate_shapes(shape) size = _generate_shapes(shape)
seed = get_seed() seed = get_seed()
if seed is not None: if seed is not None:
seed1, seed2 = _get_graph_seed(seed, "UniformReal") seed1, seed2 = _get_seed(seed, "UniformReal")
uniformreal = P.UniformReal(seed=seed1, seed2=seed2) uniformreal = P.UniformReal(seed=seed1, seed2=seed2)
else: else:
uniformreal = P.UniformReal() uniformreal = P.UniformReal()
@ -566,7 +567,7 @@ def randint(minval, maxval=None, shape=None, dtype=mstype.int32):
shape = _check_shape(shape) shape = _check_shape(shape)
seed = get_seed() seed = get_seed()
if seed is not None: if seed is not None:
seed1, seed2 = _get_graph_seed(seed, "UniformInt") seed1, seed2 = _get_seed(seed, "UniformInt")
uniformint = P.UniformInt(seed=seed1, seed2=seed2) uniformint = P.UniformInt(seed=seed1, seed2=seed2)
else: else:
uniformint = P.UniformInt() uniformint = P.UniformInt()