add get_seed() and set_seed()
This commit is contained in:
parent
c563336541
commit
4a5d115a66
|
@ -18,12 +18,14 @@ from .api import ms_function
|
|||
from .dtype import *
|
||||
from .parameter import Parameter, ParameterTuple
|
||||
from .tensor import MetaTensor, Tensor, RowTensor, SparseTensor
|
||||
from .seed import set_seed, get_seed
|
||||
|
||||
__all__ = [
|
||||
"MetaTensor", "Tensor", "RowTensor", "SparseTensor", # tensor
|
||||
'ms_function', # api
|
||||
'Parameter', 'ParameterTuple', # parameter
|
||||
"dtype"
|
||||
"dtype",
|
||||
"set_seed", "get_seed" # random seed
|
||||
]
|
||||
|
||||
__all__.extend(dtype.__all__)
|
||||
|
|
|
@ -23,6 +23,7 @@ from mindspore import log as logger
|
|||
|
||||
from . import dtype as mstype
|
||||
from .tensor import Tensor
|
||||
from .seed import get_seed
|
||||
from .._c_expression import random_normal
|
||||
|
||||
_INITIALIZER_ALIAS = dict()
|
||||
|
@ -71,7 +72,7 @@ class Initializer:
|
|||
|
||||
Args:
|
||||
slice_index (int): Slice index of a parameter's slices.
|
||||
Used when initialize a slice of a parameter, it guarantee that
|
||||
Used when initialize a slice of the parameter, it guarantee that
|
||||
devices use the same slice can generate the same tensor.
|
||||
shape (list[int]): Shape of the slice, used when initialize a slice of the parameter.
|
||||
"""
|
||||
|
@ -86,11 +87,17 @@ class Initializer:
|
|||
logger.error(msg)
|
||||
raise ValueError(msg)
|
||||
|
||||
if slice_index is not None:
|
||||
global_seed = get_seed()
|
||||
need_set_seed = ((slice_index is not None) and (global_seed is None))
|
||||
seed_saved = np.random.get_state()[1][0]
|
||||
if need_set_seed:
|
||||
np.random.seed(slice_index)
|
||||
self.__call__(arr)
|
||||
if need_set_seed:
|
||||
np.random.seed(seed_saved)
|
||||
return Tensor(arr, dtype=self.dtype)
|
||||
|
||||
|
||||
def _register(*aliases):
|
||||
"""Return the alias register."""
|
||||
def alias_reg(cls):
|
||||
|
|
|
@ -0,0 +1,54 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Provide random seed api."""
|
||||
import numpy as np
|
||||
|
||||
# set global RNG seed
|
||||
_GLOBAL_SEED = None
|
||||
|
||||
def set_seed(seed):
|
||||
"""
|
||||
Set global random seed.
|
||||
|
||||
Note:
|
||||
The global seed is used by numpy.random, mindspore.common.Initializer, mindspore.ops.composite.random_ops and
|
||||
mindspore.nn.probability.distribution.
|
||||
If global seed is not set, these packages will use their own default seed independently, numpy.random and
|
||||
mindspore.common.Initializer will choose a random seed, mindspore.ops.composite.random_ops and
|
||||
mindspore.nn.probability.distribution will use zero.
|
||||
Seed set by numpy.random.seed() only used by numpy.random, while seed set by this API will also used by
|
||||
numpy.random, so just set all seed by this API is recommended.
|
||||
|
||||
Args:
|
||||
seed (int): The seed to be set.
|
||||
|
||||
Raises:
|
||||
ValueError: If seed is invalid (< 0).
|
||||
TypeError: If seed isn't a int.
|
||||
"""
|
||||
if not isinstance(seed, int):
|
||||
raise TypeError("The seed must be type of int.")
|
||||
if seed < 0:
|
||||
raise ValueError("The seed must be greater or equal to 0.")
|
||||
np.random.seed(seed)
|
||||
global _GLOBAL_SEED
|
||||
_GLOBAL_SEED = seed
|
||||
|
||||
|
||||
def get_seed():
|
||||
"""
|
||||
Get global random seed.
|
||||
"""
|
||||
return _GLOBAL_SEED
|
|
@ -27,7 +27,7 @@ class Bernoulli(Distribution):
|
|||
|
||||
Args:
|
||||
probs (float, list, numpy.ndarray, Tensor, Parameter): probability of 1 as outcome.
|
||||
seed (int): seed to use in sampling. Default: 0.
|
||||
seed (int): seed to use in sampling. Global seed is used if it is None. Default: None.
|
||||
dtype (mindspore.dtype): type of the distribution. Default: mstype.int32.
|
||||
name (str): name of the distribution. Default: Bernoulli.
|
||||
|
||||
|
@ -91,7 +91,7 @@ class Bernoulli(Distribution):
|
|||
|
||||
def __init__(self,
|
||||
probs=None,
|
||||
seed=0,
|
||||
seed=None,
|
||||
dtype=mstype.int32,
|
||||
name="Bernoulli"):
|
||||
"""
|
||||
|
|
|
@ -27,7 +27,7 @@ class Categorical(Distribution):
|
|||
Args:
|
||||
probs (Tensor, list, numpy.ndarray, Parameter, float): event probabilities.
|
||||
logits (Tensor, list, numpy.ndarray, Parameter, float): event log-odds.
|
||||
seed (int): seed to use in sampling. Default: 0.
|
||||
seed (int): seed to use in sampling. Global seed is used if it is None. Default: None.
|
||||
dtype (mstype.int32): type of the distribution. Default: mstype.int32.
|
||||
name (str): name of the distribution. Default: Categorical.
|
||||
|
||||
|
@ -67,7 +67,7 @@ class Categorical(Distribution):
|
|||
def __init__(self,
|
||||
probs=None,
|
||||
logits=None,
|
||||
seed=0,
|
||||
seed=None,
|
||||
dtype=mstype.int32,
|
||||
name="Categorical"):
|
||||
param = dict(locals())
|
||||
|
@ -83,7 +83,7 @@ class Categorical(Distribution):
|
|||
self.reshape = P.Reshape()
|
||||
self.div = P.RealDiv()
|
||||
self.size = P.Size()
|
||||
self.mutinomial = P.Multinomial(seed=seed)
|
||||
self.mutinomial = P.Multinomial(seed=self.seed)
|
||||
self.cast = P.Cast()
|
||||
self.expandim = P.ExpandDims()
|
||||
self.gather = P.GatherNd()
|
||||
|
|
|
@ -17,6 +17,7 @@ from mindspore import context
|
|||
from mindspore.nn.cell import Cell
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from mindspore._checkparam import Rel
|
||||
from mindspore.common import get_seed
|
||||
from ._utils.utils import calc_broadcast_shape_from_param, check_scalar_from_param, cast_type_for_device
|
||||
from ._utils.utils import CheckTuple, CheckTensor
|
||||
|
||||
|
@ -26,7 +27,7 @@ class Distribution(Cell):
|
|||
Base class for all mathematical distributions.
|
||||
|
||||
Args:
|
||||
seed (int): random seed used in sampling.
|
||||
seed (int): random seed used in sampling. Global seed is used if it is None. Default: None.
|
||||
dtype (mindspore.dtype): the type of the event samples. Default: subclass dtype.
|
||||
name (str): Python str name prefixed to Ops created by this class. Default: subclass name.
|
||||
param (dict): parameters used to initialize the distribution.
|
||||
|
@ -56,6 +57,10 @@ class Distribution(Cell):
|
|||
Constructor of distribution class.
|
||||
"""
|
||||
super(Distribution, self).__init__()
|
||||
if seed is None:
|
||||
seed = get_seed()
|
||||
if seed is None:
|
||||
seed = 0
|
||||
validator.check_value_type('name', name, [str], type(self).__name__)
|
||||
validator.check_integer('seed', seed, 0, Rel.GE, name)
|
||||
|
||||
|
|
|
@ -28,7 +28,7 @@ class Exponential(Distribution):
|
|||
|
||||
Args:
|
||||
rate (float, list, numpy.ndarray, Tensor, Parameter): inverse scale.
|
||||
seed (int): seed to use in sampling. Default: 0.
|
||||
seed (int): seed to use in sampling. Global seed is used if it is None. Default: None.
|
||||
dtype (mindspore.dtype): type of the distribution. Default: mstype.float32.
|
||||
name (str): name of the distribution. Default: Exponential.
|
||||
|
||||
|
@ -92,7 +92,7 @@ class Exponential(Distribution):
|
|||
|
||||
def __init__(self,
|
||||
rate=None,
|
||||
seed=0,
|
||||
seed=None,
|
||||
dtype=mstype.float32,
|
||||
name="Exponential"):
|
||||
"""
|
||||
|
|
|
@ -30,7 +30,7 @@ class Geometric(Distribution):
|
|||
|
||||
Args:
|
||||
probs (float, list, numpy.ndarray, Tensor, Parameter): probability of success.
|
||||
seed (int): seed to use in sampling. Default: 0.
|
||||
seed (int): seed to use in sampling. Global seed is used if it is None. Default: None.
|
||||
dtype (mindspore.dtype): type of the distribution. Default: mstype.int32.
|
||||
name (str): name of the distribution. Default: Geometric.
|
||||
|
||||
|
@ -94,7 +94,7 @@ class Geometric(Distribution):
|
|||
|
||||
def __init__(self,
|
||||
probs=None,
|
||||
seed=0,
|
||||
seed=None,
|
||||
dtype=mstype.int32,
|
||||
name="Geometric"):
|
||||
"""
|
||||
|
|
|
@ -29,7 +29,7 @@ class Normal(Distribution):
|
|||
Args:
|
||||
mean (int, float, list, numpy.ndarray, Tensor, Parameter): mean of the Normal distribution.
|
||||
sd (int, float, list, numpy.ndarray, Tensor, Parameter): stddev of the Normal distribution.
|
||||
seed (int): seed to use in sampling. Default: 0.
|
||||
seed (int): seed to use in sampling. Global seed is used if it is None. Default: None.
|
||||
dtype (mindspore.dtype): type of the distribution. Default: mstype.float32.
|
||||
name (str): name of the distribution. Default: Normal.
|
||||
|
||||
|
@ -94,7 +94,7 @@ class Normal(Distribution):
|
|||
def __init__(self,
|
||||
mean=None,
|
||||
sd=None,
|
||||
seed=0,
|
||||
seed=None,
|
||||
dtype=mstype.float32,
|
||||
name="Normal"):
|
||||
"""
|
||||
|
|
|
@ -62,7 +62,7 @@ class TransformedDistribution(Distribution):
|
|||
bijector,
|
||||
distribution,
|
||||
dtype,
|
||||
seed=0,
|
||||
seed=None,
|
||||
name="transformed_distribution"):
|
||||
"""
|
||||
Constructor of transformed_distribution class.
|
||||
|
|
|
@ -28,7 +28,7 @@ class Uniform(Distribution):
|
|||
Args:
|
||||
low (int, float, list, numpy.ndarray, Tensor, Parameter): lower bound of the distribution.
|
||||
high (int, float, list, numpy.ndarray, Tensor, Parameter): upper bound of the distribution.
|
||||
seed (int): seed to use in sampling. Default: 0.
|
||||
seed (int): seed to use in sampling. Global seed is used if it is None. Default: None.
|
||||
dtype (mindspore.dtype): type of the distribution. Default: mstype.float32.
|
||||
name (str): name of the distribution. Default: Uniform.
|
||||
|
||||
|
@ -93,7 +93,7 @@ class Uniform(Distribution):
|
|||
def __init__(self,
|
||||
low=None,
|
||||
high=None,
|
||||
seed=0,
|
||||
seed=None,
|
||||
dtype=mstype.float32,
|
||||
name="Uniform"):
|
||||
"""
|
||||
|
|
|
@ -26,7 +26,7 @@ from .clip_ops import clip_by_value
|
|||
from .multitype_ops.add_impl import hyper_add
|
||||
from .multitype_ops.ones_like_impl import ones_like
|
||||
from .multitype_ops.zeros_like_impl import zeros_like
|
||||
from .random_ops import set_seed, normal, uniform, gamma, poisson, multinomial
|
||||
from .random_ops import normal, uniform, gamma, poisson, multinomial
|
||||
|
||||
|
||||
__all__ = [
|
||||
|
@ -41,7 +41,6 @@ __all__ = [
|
|||
'zeros_like',
|
||||
'ones_like',
|
||||
'zip_operation',
|
||||
'set_seed',
|
||||
'normal',
|
||||
'uniform',
|
||||
'gamma',
|
||||
|
|
|
@ -20,34 +20,14 @@ from .. import functional as F
|
|||
from ..primitive import constexpr
|
||||
from .multitype_ops import _constexpr_utils as const_utils
|
||||
from ...common import dtype as mstype
|
||||
|
||||
# set graph-level RNG seed
|
||||
_GRAPH_SEED = 0
|
||||
|
||||
@constexpr
|
||||
def set_seed(seed):
|
||||
"""
|
||||
Set the graph-level seed.
|
||||
Graph-level seed is used as a global variable, that can be used in different ops in case op-level seed is not set.
|
||||
If op-level seed is 0, use graph-level seed; if op-level seed is also 0, the system would generate a
|
||||
random seed.
|
||||
|
||||
Args:
|
||||
seed(Int): the graph-level seed value that to be set. Must be non-negative.
|
||||
|
||||
Examples:
|
||||
>>> C.set_seed(10)
|
||||
"""
|
||||
const_utils.check_non_negative("seed", seed, "set_seed")
|
||||
global _GRAPH_SEED
|
||||
_GRAPH_SEED = seed
|
||||
from ...common import get_seed as get_global_seed
|
||||
|
||||
@constexpr
|
||||
def get_seed():
|
||||
"""
|
||||
Get the graph-level seed.
|
||||
Graph-level seed is used as a global variable, that can be used in different ops in case op-level seed is not set.
|
||||
If op-level seed is 0, use graph-level seed; if op-level seed is also 0, the system would generate a
|
||||
If op-level seed is 0, use graph-level seed; if graph-level seed is also 0, the system would generate a
|
||||
random seed.
|
||||
|
||||
Returns:
|
||||
|
@ -56,7 +36,10 @@ def get_seed():
|
|||
Examples:
|
||||
>>> C.get_seed()
|
||||
"""
|
||||
return _GRAPH_SEED
|
||||
global_seed = get_global_seed()
|
||||
if global_seed is None:
|
||||
return 0
|
||||
return global_seed
|
||||
|
||||
def normal(shape, mean, stddev, seed=0):
|
||||
"""
|
||||
|
|
|
@ -19,6 +19,7 @@ import mindspore.nn as nn
|
|||
from mindspore import Tensor
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.common import set_seed
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
|
||||
|
@ -30,7 +31,7 @@ class Net(nn.Cell):
|
|||
self.seed = seed
|
||||
|
||||
def construct(self, alpha, beta):
|
||||
C.set_seed(20)
|
||||
set_seed(20)
|
||||
return C.gamma(self.shape, alpha, beta, self.seed)
|
||||
|
||||
|
||||
|
|
|
@ -19,6 +19,7 @@ import mindspore.nn as nn
|
|||
from mindspore import Tensor
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.common import set_seed
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
|
||||
|
@ -30,7 +31,7 @@ class Net(nn.Cell):
|
|||
self.seed = seed
|
||||
|
||||
def construct(self, mean, stddev):
|
||||
C.set_seed(20)
|
||||
set_seed(20)
|
||||
return C.normal(self.shape, mean, stddev, self.seed)
|
||||
|
||||
|
||||
|
|
|
@ -19,6 +19,7 @@ import mindspore.nn as nn
|
|||
from mindspore import Tensor
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.common import set_seed
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
|
||||
|
@ -30,7 +31,7 @@ class Net(nn.Cell):
|
|||
self.seed = seed
|
||||
|
||||
def construct(self, mean):
|
||||
C.set_seed(20)
|
||||
set_seed(20)
|
||||
return C.poisson(self.shape, mean, self.seed)
|
||||
|
||||
|
||||
|
|
|
@ -19,6 +19,7 @@ import mindspore.nn as nn
|
|||
from mindspore import Tensor
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.common import set_seed
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
|
||||
|
@ -30,7 +31,7 @@ class Net(nn.Cell):
|
|||
self.seed = seed
|
||||
|
||||
def construct(self, minval, maxval):
|
||||
C.set_seed(20)
|
||||
set_seed(20)
|
||||
return C.uniform(self.shape, minval, maxval, self.seed)
|
||||
|
||||
|
||||
|
|
|
@ -21,6 +21,7 @@ from mindspore import Tensor, Parameter
|
|||
import mindspore as ms
|
||||
import mindspore.common.api as me
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore.common import set_seed
|
||||
from hccl_test.manage.api import Hccl
|
||||
|
||||
class Net(nn.Cell):
|
||||
|
@ -112,5 +113,42 @@ def test_wrong_order_set_parallel_mode_without_initializer():
|
|||
net.set_auto_parallel()
|
||||
exe.compile(net, x, auto_parallel_mode=True, phase='train')
|
||||
|
||||
def test_check_initializer_weight_slice_seed(init_name="Uniform"):
|
||||
def get_slice(rank):
|
||||
set_seed(1)
|
||||
hccl = Hccl()
|
||||
rank_save = hccl.rank_id
|
||||
hccl.rank_id = rank
|
||||
context.reset_auto_parallel_context()
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0)
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
|
||||
strategy1 = ((2, 1), (4, 1))
|
||||
strategy2 = ((2, 4),)
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
exe = me._executor
|
||||
|
||||
x = Tensor(np.ones([32, 32]), dtype=ms.float32)
|
||||
weight = initializer(init_name, [64, 32], ms.float32)
|
||||
net = Net(strategy1, strategy2, weight)
|
||||
net.set_auto_parallel()
|
||||
exe.compile(net, x, auto_parallel_mode=True, phase='train')
|
||||
hccl.rank_id = rank_save
|
||||
return net.parameters_dict()['w1'].data.asnumpy()
|
||||
|
||||
|
||||
slice0 = get_slice(0)
|
||||
slice1 = get_slice(1)
|
||||
slice4 = get_slice(4)
|
||||
slice_shape = slice0.shape
|
||||
|
||||
slice0 = slice0.flatten()
|
||||
slice1 = slice1.flatten()
|
||||
slice4 = slice4.flatten()
|
||||
expect_slice_shape = (16, 32)
|
||||
|
||||
assert expect_slice_shape == slice_shape
|
||||
assert all(slice0 == slice4)
|
||||
assert all(slice0 == slice1)
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_initializer_weight_slice()
|
||||
|
|
Loading…
Reference in New Issue