forked from mindspore-Ecosystem/mindspore
fix bug
This commit is contained in:
parent
94ad76bc14
commit
4eecce7ab4
|
@ -25,10 +25,11 @@ from mindspore import context
|
|||
from mindspore import ops
|
||||
from mindspore.common import Tensor
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common.seed import get_seed, _get_graph_seed
|
||||
from mindspore.common.seed import get_seed
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops.primitive import constexpr, _primexpr
|
||||
from mindspore.ops.function.random_func import _get_seed
|
||||
from mindspore.nn.layer.basic import tril as nn_tril
|
||||
from mindspore.nn.layer.basic import triu as nn_triu
|
||||
from mindspore._c_expression import Tensor as Tensor_
|
||||
|
@ -455,7 +456,7 @@ def randn(*shape, dtype=mstype.float32):
|
|||
size = _generate_shapes(shape)
|
||||
seed = get_seed()
|
||||
if seed is not None:
|
||||
seed1, seed2 = _get_graph_seed(seed, "StandardNormal")
|
||||
seed1, seed2 = _get_seed(seed, "StandardNormal")
|
||||
stdnormal = P.StandardNormal(seed=seed1, seed2=seed2)
|
||||
else:
|
||||
stdnormal = P.StandardNormal()
|
||||
|
@ -496,7 +497,7 @@ def rand(*shape, dtype=mstype.float32):
|
|||
size = _generate_shapes(shape)
|
||||
seed = get_seed()
|
||||
if seed is not None:
|
||||
seed1, seed2 = _get_graph_seed(seed, "UniformReal")
|
||||
seed1, seed2 = _get_seed(seed, "UniformReal")
|
||||
uniformreal = P.UniformReal(seed=seed1, seed2=seed2)
|
||||
else:
|
||||
uniformreal = P.UniformReal()
|
||||
|
@ -566,7 +567,7 @@ def randint(minval, maxval=None, shape=None, dtype=mstype.int32):
|
|||
shape = _check_shape(shape)
|
||||
seed = get_seed()
|
||||
if seed is not None:
|
||||
seed1, seed2 = _get_graph_seed(seed, "UniformInt")
|
||||
seed1, seed2 = _get_seed(seed, "UniformInt")
|
||||
uniformint = P.UniformInt(seed=seed1, seed2=seed2)
|
||||
else:
|
||||
uniformint = P.UniformInt()
|
||||
|
|
Loading…
Reference in New Issue