forked from mindspore-Ecosystem/mindspore
!6600 Update the convention that random seed works.
Merge pull request !6600 from peixu_ren/custom_bijector
This commit is contained in:
commit
467ed2ccd0
|
@ -18,7 +18,7 @@ from .api import ms_function
|
|||
from .dtype import *
|
||||
from .parameter import Parameter, ParameterTuple
|
||||
from .tensor import MetaTensor, Tensor, RowTensor, SparseTensor
|
||||
from .seed import set_seed, get_seed
|
||||
from .seed import set_seed, get_seed, _truncate_seed, _update_seeds, _get_op_seed
|
||||
|
||||
|
||||
__all__ = dtype.__all__
|
||||
|
@ -27,5 +27,5 @@ __all__.extend([
|
|||
'ms_function', # api
|
||||
'Parameter', 'ParameterTuple', # parameter
|
||||
"dtype",
|
||||
"set_seed", "get_seed" # random seed
|
||||
"set_seed", "get_seed", '_truncate_seed', '_update_seeds', '_get_op_seed' # random seed
|
||||
])
|
||||
|
|
|
@ -16,8 +16,13 @@
|
|||
import numpy as np
|
||||
import mindspore.dataset as de
|
||||
|
||||
# constants
|
||||
_MAXINT32 = 2**31 - 1
|
||||
keyConstant = [3528531795, 2654435769, 3449720151, 3144134277]
|
||||
|
||||
# set global RNG seed
|
||||
_GLOBAL_SEED = None
|
||||
_KERNEL_SEED = {}
|
||||
|
||||
def set_seed(seed):
|
||||
"""
|
||||
|
@ -25,12 +30,12 @@ def set_seed(seed):
|
|||
|
||||
Note:
|
||||
The global seed is used by numpy.random, mindspore.common.Initializer, mindspore.ops.composite.random_ops and
|
||||
mindspore.nn.probability.distribution.
|
||||
mindspore.nn.probability.distribution.
|
||||
If global seed is not set, these packages will use their own default seed independently, numpy.random and
|
||||
mindspore.common.Initializer will choose a random seed, mindspore.ops.composite.random_ops and
|
||||
mindspore.nn.probability.distribution will use zero.
|
||||
mindspore.common.Initializer will choose a random seed, mindspore.ops.composite.random_ops and
|
||||
mindspore.nn.probability.distribution will use zero.
|
||||
Seed set by numpy.random.seed() only used by numpy.random, while seed set by this API will also used by
|
||||
numpy.random, so just set all seed by this API is recommended.
|
||||
numpy.random, so just set all seed by this API is recommended.
|
||||
|
||||
Args:
|
||||
seed (int): The seed to be set.
|
||||
|
@ -45,6 +50,7 @@ def set_seed(seed):
|
|||
raise ValueError("The seed must be greater or equal to 0.")
|
||||
np.random.seed(seed)
|
||||
de.config.set_seed(seed)
|
||||
_reset_op_seed()
|
||||
global _GLOBAL_SEED
|
||||
_GLOBAL_SEED = seed
|
||||
|
||||
|
@ -54,3 +60,51 @@ def get_seed():
|
|||
Get global random seed.
|
||||
"""
|
||||
return _GLOBAL_SEED
|
||||
|
||||
|
||||
def _truncate_seed(seed):
|
||||
"""
|
||||
Truncate the seed with MAXINT32.
|
||||
|
||||
Args:
|
||||
seed (int): The seed to be truncated.
|
||||
"""
|
||||
return seed % _MAXINT32 # Truncate to fit into 32-bit integer
|
||||
|
||||
|
||||
def _update_seeds(op_seed, kernel_name):
|
||||
"""
|
||||
Update the seed every time when a random op is called.
|
||||
|
||||
Args:
|
||||
seed (int): The op-seed to be updated.
|
||||
kernel_name (string): The random op kernel.
|
||||
"""
|
||||
global _GLOBAL_SEED
|
||||
global _KERNEL_SEED
|
||||
if _GLOBAL_SEED is not None:
|
||||
_GLOBAL_SEED += keyConstant[1] + keyConstant[3] * (2**8)
|
||||
if op_seed is not None:
|
||||
_KERNEL_SEED[(kernel_name, op_seed)] = _KERNEL_SEED[(kernel_name, op_seed)] + (keyConstant[0] ^ keyConstant[2])
|
||||
|
||||
|
||||
def _get_op_seed(op_seed, kernel_name):
|
||||
"""
|
||||
Get op seed which is relating to the specific kernel.
|
||||
If the seed does not exist, add it into the kernel's dictionary.
|
||||
|
||||
Args:
|
||||
seed (int): The op-seed to be updated.
|
||||
kernel_name (string): The random op kernel.
|
||||
"""
|
||||
if ((kernel_name, op_seed) not in _KERNEL_SEED) or (_KERNEL_SEED[(kernel_name, op_seed)] == -1):
|
||||
_KERNEL_SEED[(kernel_name, op_seed)] = op_seed
|
||||
_KERNEL_SEED[(kernel_name, op_seed)] = 0
|
||||
return _KERNEL_SEED[(kernel_name, op_seed)]
|
||||
|
||||
def _reset_op_seed():
|
||||
"""
|
||||
Reset op seeds in the kernel's dictionary.
|
||||
"""
|
||||
for key in _KERNEL_SEED:
|
||||
_KERNEL_SEED[key] = -1
|
||||
|
|
|
@ -71,7 +71,7 @@ def check_int_positive(arg_name, arg_value, op_name):
|
|||
|
||||
|
||||
@constexpr
|
||||
def check_non_negative(arg_name, arg_value, op_name):
|
||||
def check_int_non_negative(arg_name, arg_value, op_name):
|
||||
"""Int type judgment."""
|
||||
if isinstance(arg_value, int):
|
||||
if arg_value >= 0:
|
||||
|
|
|
@ -21,27 +21,45 @@ from ..primitive import constexpr
|
|||
from .multitype_ops import _constexpr_utils as const_utils
|
||||
from ...common import dtype as mstype
|
||||
from ...common import get_seed as get_global_seed
|
||||
from ...common import _truncate_seed, _update_seeds, _get_op_seed
|
||||
|
||||
@constexpr
|
||||
def get_seed():
|
||||
def get_seed(op_seed, kernel_name):
|
||||
"""
|
||||
Get the graph-level seed.
|
||||
Graph-level seed is used as a global variable, that can be used in different ops in case op-level seed is not set.
|
||||
If op-level seed is 0, use graph-level seed; if graph-level seed is also 0, the system would generate a
|
||||
random seed.
|
||||
|
||||
Note:
|
||||
For each seed, either op-seed or graph-seed, a random sequence will be generated relating to this seed.
|
||||
So, the state of the seed regarding to this op should be recorded.
|
||||
A simple illustration should be:
|
||||
If a random op is called twice within one program, the two results should be different:
|
||||
print(C.uniform((1, 4), seed=1)) # generates 'A1'
|
||||
print(C.uniform((1, 4), seed=1)) # generates 'A2'
|
||||
If the same program runs again, it repeat the results:
|
||||
print(C.uniform((1, 4), seed=1)) # generates 'A1'
|
||||
print(C.uniform((1, 4), seed=1)) # generates 'A2'
|
||||
|
||||
Returns:
|
||||
Interger. The current graph-level seed.
|
||||
|
||||
Examples:
|
||||
>>> C.get_seed()
|
||||
>>> C.get_seed(seed, 'normal')
|
||||
"""
|
||||
global_seed = get_global_seed()
|
||||
if global_seed is None:
|
||||
return 0
|
||||
return global_seed
|
||||
global_seed = 0
|
||||
if op_seed is None:
|
||||
temp_seed = _get_op_seed(0, kernel_name)
|
||||
else:
|
||||
temp_seed = _get_op_seed(op_seed, kernel_name)
|
||||
seeds = _truncate_seed(global_seed), _truncate_seed(temp_seed)
|
||||
_update_seeds(op_seed, kernel_name)
|
||||
return seeds
|
||||
|
||||
def normal(shape, mean, stddev, seed=0):
|
||||
def normal(shape, mean, stddev, seed=None):
|
||||
"""
|
||||
Generates random numbers according to the Normal (or Gaussian) random number distribution.
|
||||
|
||||
|
@ -52,7 +70,7 @@ def normal(shape, mean, stddev, seed=0):
|
|||
stddev (Tensor): The deviation σ distribution parameter. It should be greater than 0.
|
||||
with float32 data type.
|
||||
seed (int): Seed is used as entropy source for the Random number engines to generate pseudo-random numbers.
|
||||
must be non-negative. Default: 0.
|
||||
must be non-negative. Default: None, which will be treated as 0.
|
||||
|
||||
Returns:
|
||||
Tensor. The shape should be equal to the broadcasted shape between the input `shape` and shapes
|
||||
|
@ -69,15 +87,14 @@ def normal(shape, mean, stddev, seed=0):
|
|||
stddev_dtype = F.dtype(stddev)
|
||||
const_utils.check_tensors_dtype_same(mean_dtype, mstype.float32, "normal")
|
||||
const_utils.check_tensors_dtype_same(stddev_dtype, mstype.float32, "normal")
|
||||
const_utils.check_non_negative("seed", seed, "normal")
|
||||
seed1 = get_seed()
|
||||
seed2 = seed
|
||||
seed1, seed2 = get_seed(seed, "normal")
|
||||
const_utils.check_int_non_negative("seed", seed2, "normal")
|
||||
stdnormal = P.StandardNormal(seed1, seed2)
|
||||
random_normal = stdnormal(shape)
|
||||
value = random_normal * stddev + mean
|
||||
return value
|
||||
|
||||
def laplace(shape, mean, lambda_param, seed=0):
|
||||
def laplace(shape, mean, lambda_param, seed=None):
|
||||
r"""
|
||||
Generates random numbers according to the Laplace random number distribution.
|
||||
It is defined as:
|
||||
|
@ -92,7 +109,7 @@ def laplace(shape, mean, lambda_param, seed=0):
|
|||
lambda_param (Tensor): The parameter used for controling the variance of this random distribution. The
|
||||
variance of Laplace distribution is equal to twice the square of lambda_param. With float32 data type.
|
||||
seed (int): Seed is used as entropy source for Random number engines generating pseudo-random numbers.
|
||||
Default: 0.
|
||||
Default: None, which will be treated as 0.
|
||||
|
||||
Returns:
|
||||
Tensor. The shape should be the broadcasted shape of Input "shape" and shapes of mean and lambda_param.
|
||||
|
@ -108,14 +125,14 @@ def laplace(shape, mean, lambda_param, seed=0):
|
|||
lambda_param_dtype = F.dtype(lambda_param)
|
||||
const_utils.check_tensors_dtype_same(mean_dtype, mstype.float32, "laplace")
|
||||
const_utils.check_tensors_dtype_same(lambda_param_dtype, mstype.float32, "laplace")
|
||||
seed1 = get_seed()
|
||||
seed2 = seed
|
||||
seed1, seed2 = get_seed(seed, "laplace")
|
||||
const_utils.check_int_non_negative("seed", seed2, "laplace")
|
||||
stdlaplace = P.StandardLaplace(seed1, seed2)
|
||||
rnd = stdlaplace(shape)
|
||||
value = rnd * lambda_param + mean
|
||||
return value
|
||||
|
||||
def uniform(shape, minval, maxval, seed=0, dtype=mstype.float32):
|
||||
def uniform(shape, minval, maxval, seed=None, dtype=mstype.float32):
|
||||
"""
|
||||
Generates random numbers according to the Uniform random number distribution.
|
||||
|
||||
|
@ -131,7 +148,7 @@ def uniform(shape, minval, maxval, seed=0, dtype=mstype.float32):
|
|||
It defines the maximum possible generated value, with int32 or float32 data type.
|
||||
If dtype is int32, only one number is allowed.
|
||||
seed (int): Seed is used as entropy source for the random number engines to generate pseudo-random numbers,
|
||||
must be non-negative. Default: 0.
|
||||
must be non-negative. Default: None, which will be treated as 0.
|
||||
dtype (mindspore.dtype): type of the Uniform distribution. If it is int32, it generates numbers from discrete
|
||||
uniform distribution; if it is float32, it generates numbers from continuous uniform distribution. It only
|
||||
supports these two data types. Default: mstype.float32.
|
||||
|
@ -159,9 +176,8 @@ def uniform(shape, minval, maxval, seed=0, dtype=mstype.float32):
|
|||
const_utils.check_valid_type(dtype, [mstype.int32, mstype.float32], 'uniform')
|
||||
const_utils.check_tensors_dtype_same(minval_dtype, dtype, "uniform")
|
||||
const_utils.check_tensors_dtype_same(maxval_dtype, dtype, "uniform")
|
||||
const_utils.check_non_negative("seed", seed, "uniform")
|
||||
seed1 = get_seed()
|
||||
seed2 = seed
|
||||
seed1, seed2 = get_seed(seed, "uniform")
|
||||
const_utils.check_int_non_negative("seed", seed2, "uniform")
|
||||
if const_utils.is_same_type(dtype, mstype.int32):
|
||||
random_uniform = P.UniformInt(seed1, seed2)
|
||||
value = random_uniform(shape, minval, maxval)
|
||||
|
@ -171,7 +187,7 @@ def uniform(shape, minval, maxval, seed=0, dtype=mstype.float32):
|
|||
value = random_uniform * (maxval - minval) + minval
|
||||
return value
|
||||
|
||||
def gamma(shape, alpha, beta, seed=0):
|
||||
def gamma(shape, alpha, beta, seed=None):
|
||||
"""
|
||||
Generates random numbers according to the Gamma random number distribution.
|
||||
|
||||
|
@ -180,7 +196,7 @@ def gamma(shape, alpha, beta, seed=0):
|
|||
alpha (Tensor): The alpha α distribution parameter. It should be greater than 0 with float32 data type.
|
||||
beta (Tensor): The beta β distribution parameter. It should be greater than 0 with float32 data type.
|
||||
seed (int): Seed is used as entropy source for the random number engines to generate
|
||||
pseudo-random numbers, must be non-negative. Default: 0.
|
||||
pseudo-random numbers, must be non-negative. Default: None, which will be treated as 0.
|
||||
|
||||
Returns:
|
||||
Tensor. The shape should be equal to the broadcasted shape between the input "shape" and shapes
|
||||
|
@ -193,14 +209,13 @@ def gamma(shape, alpha, beta, seed=0):
|
|||
>>> beta = Tensor(1.0, mstype.float32)
|
||||
>>> output = C.gamma(shape, alpha, beta, seed=5)
|
||||
"""
|
||||
const_utils.check_non_negative("seed", seed, "gamma")
|
||||
seed1 = get_seed()
|
||||
seed2 = seed
|
||||
seed1, seed2 = get_seed(seed, "gamma")
|
||||
const_utils.check_int_non_negative("seed", seed2, "gamma")
|
||||
random_gamma = P.Gamma(seed1, seed2)
|
||||
value = random_gamma(shape, alpha, beta)
|
||||
return value
|
||||
|
||||
def poisson(shape, mean, seed=0):
|
||||
def poisson(shape, mean, seed=None):
|
||||
"""
|
||||
Generates random numbers according to the Poisson random number distribution.
|
||||
|
||||
|
@ -208,7 +223,7 @@ def poisson(shape, mean, seed=0):
|
|||
shape (tuple): The shape of random tensor to be generated.
|
||||
mean (Tensor): The mean μ distribution parameter. It should be greater than 0 with float32 data type.
|
||||
seed (int): Seed is used as entropy source for the random number engines to generate pseudo-random numbers
|
||||
and must be non-negative. Default: 0.
|
||||
and must be non-negative. Default: None, which will be treated as 0.
|
||||
|
||||
Returns:
|
||||
Tensor. The shape should be equal to the broadcasted shape between the input "shape" and shapes of `mean`.
|
||||
|
@ -219,9 +234,8 @@ def poisson(shape, mean, seed=0):
|
|||
>>> mean = Tensor(1.0, mstype.float32)
|
||||
>>> output = C.poisson(shape, mean, seed=5)
|
||||
"""
|
||||
const_utils.check_non_negative("seed", seed, "poisson")
|
||||
seed1 = get_seed()
|
||||
seed2 = seed
|
||||
seed1, seed2 = get_seed(seed, "poisson")
|
||||
const_utils.check_int_non_negative("seed", seed2, "poisson")
|
||||
random_poisson = P.Poisson(seed1, seed2)
|
||||
value = random_poisson(shape, mean)
|
||||
return value
|
||||
|
|
Loading…
Reference in New Issue