forked from mindspore-Ecosystem/mindspore
!49823 fix get_seed graph mode bug
Merge pull request !49823 from YingtongHu/master
This commit is contained in:
commit
f1bb1d9130
|
@ -25,10 +25,11 @@ from mindspore import context
|
||||||
from mindspore import ops
|
from mindspore import ops
|
||||||
from mindspore.common import Tensor
|
from mindspore.common import Tensor
|
||||||
from mindspore.common import dtype as mstype
|
from mindspore.common import dtype as mstype
|
||||||
from mindspore.common.seed import get_seed, _get_graph_seed
|
from mindspore.common.seed import get_seed
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
from mindspore.ops import functional as F
|
from mindspore.ops import functional as F
|
||||||
from mindspore.ops.primitive import constexpr, _primexpr
|
from mindspore.ops.primitive import constexpr, _primexpr
|
||||||
|
from mindspore.ops.function.random_func import _get_seed
|
||||||
from mindspore.nn.layer.basic import tril as nn_tril
|
from mindspore.nn.layer.basic import tril as nn_tril
|
||||||
from mindspore.nn.layer.basic import triu as nn_triu
|
from mindspore.nn.layer.basic import triu as nn_triu
|
||||||
from mindspore._c_expression import Tensor as Tensor_
|
from mindspore._c_expression import Tensor as Tensor_
|
||||||
|
@ -455,7 +456,7 @@ def randn(*shape, dtype=mstype.float32):
|
||||||
size = _generate_shapes(shape)
|
size = _generate_shapes(shape)
|
||||||
seed = get_seed()
|
seed = get_seed()
|
||||||
if seed is not None:
|
if seed is not None:
|
||||||
seed1, seed2 = _get_graph_seed(seed, "StandardNormal")
|
seed1, seed2 = _get_seed(seed, "StandardNormal")
|
||||||
stdnormal = P.StandardNormal(seed=seed1, seed2=seed2)
|
stdnormal = P.StandardNormal(seed=seed1, seed2=seed2)
|
||||||
else:
|
else:
|
||||||
stdnormal = P.StandardNormal()
|
stdnormal = P.StandardNormal()
|
||||||
|
@ -496,7 +497,7 @@ def rand(*shape, dtype=mstype.float32):
|
||||||
size = _generate_shapes(shape)
|
size = _generate_shapes(shape)
|
||||||
seed = get_seed()
|
seed = get_seed()
|
||||||
if seed is not None:
|
if seed is not None:
|
||||||
seed1, seed2 = _get_graph_seed(seed, "UniformReal")
|
seed1, seed2 = _get_seed(seed, "UniformReal")
|
||||||
uniformreal = P.UniformReal(seed=seed1, seed2=seed2)
|
uniformreal = P.UniformReal(seed=seed1, seed2=seed2)
|
||||||
else:
|
else:
|
||||||
uniformreal = P.UniformReal()
|
uniformreal = P.UniformReal()
|
||||||
|
@ -566,7 +567,7 @@ def randint(minval, maxval=None, shape=None, dtype=mstype.int32):
|
||||||
shape = _check_shape(shape)
|
shape = _check_shape(shape)
|
||||||
seed = get_seed()
|
seed = get_seed()
|
||||||
if seed is not None:
|
if seed is not None:
|
||||||
seed1, seed2 = _get_graph_seed(seed, "UniformInt")
|
seed1, seed2 = _get_seed(seed, "UniformInt")
|
||||||
uniformint = P.UniformInt(seed=seed1, seed2=seed2)
|
uniformint = P.UniformInt(seed=seed1, seed2=seed2)
|
||||||
else:
|
else:
|
||||||
uniformint = P.UniformInt()
|
uniformint = P.UniformInt()
|
||||||
|
|
Loading…
Reference in New Issue