modify_normal_seed
This commit is contained in:
parent
c3d1f54c7f
commit
04075671cf
|
@ -19,7 +19,7 @@ import math
|
|||
from functools import reduce
|
||||
import numpy as np
|
||||
from scipy.stats import truncnorm
|
||||
|
||||
from .seed import _get_graph_seed
|
||||
from . import dtype as mstype
|
||||
from .tensor import Tensor, MetaTensor
|
||||
from .._c_expression import random_normal
|
||||
|
@ -40,8 +40,19 @@ class Initializer:
|
|||
"""
|
||||
def __init__(self, **kwargs):
|
||||
self._kwargs = kwargs
|
||||
self.shape = None
|
||||
self.dtype = None
|
||||
self._seed = None
|
||||
|
||||
@property
|
||||
def seed(self):
|
||||
seed_ = self._seed if self._seed is not None else 1
|
||||
_, seed = _get_graph_seed(seed_, "init")
|
||||
return seed
|
||||
|
||||
@seed.setter
|
||||
def seed(self, value):
|
||||
if not isinstance(value, int):
|
||||
raise TypeError("'value' must be int type.")
|
||||
self._seed = value
|
||||
|
||||
def _initialize(self, *kwargs):
|
||||
raise NotImplementedError('Must be overridden!')
|
||||
|
@ -353,7 +364,7 @@ class Normal(Initializer):
|
|||
self.sigma = sigma
|
||||
|
||||
def _initialize(self, arr):
|
||||
seed = np.random.get_state()[1][0]
|
||||
seed = self.seed
|
||||
output_tensor = Tensor(np.zeros(arr.shape, dtype=np.float32))
|
||||
random_normal(0, self.sigma, arr.shape, seed, output_tensor)
|
||||
output_data = output_tensor.asnumpy()
|
||||
|
@ -434,8 +445,7 @@ def initializer(init, shape=None, dtype=mstype.float32):
|
|||
elif isinstance(init, numbers.Number):
|
||||
init = Constant(init)
|
||||
shape = shape if shape is not None else init.shape
|
||||
dtype = init.dtype if init.dtype is not None else dtype
|
||||
init_obj = MetaTensor(init, dtype, shape)
|
||||
init_obj = MetaTensor(dtype, shape, init)
|
||||
return init_obj
|
||||
|
||||
__all__ = [
|
||||
|
|
|
@ -30,7 +30,7 @@ def _reset_op_seed():
|
|||
"""
|
||||
Reset op seeds in the kernel's dictionary.
|
||||
"""
|
||||
for kernel_name, op_seed in _KERNEL_SEED.items():
|
||||
for (kernel_name, op_seed) in _KERNEL_SEED:
|
||||
_KERNEL_SEED[(kernel_name, op_seed)] = op_seed
|
||||
|
||||
|
||||
|
|
|
@ -404,7 +404,7 @@ class MetaTensor(MetaTensor_):
|
|||
Returns:
|
||||
Array, an array after being initialized.
|
||||
"""
|
||||
def __init__(self, init, dtype, shape):
|
||||
def __init__(self, dtype, shape, init=None):
|
||||
#check param
|
||||
self.init = init
|
||||
MetaTensor_.__init__(self, dtype, shape)
|
||||
|
@ -419,6 +419,9 @@ class MetaTensor(MetaTensor_):
|
|||
using the same slice can generate the same tensor.
|
||||
shape (list[int]): Shape of the slice, it is used when initialize a slice of the parameter.
|
||||
"""
|
||||
if self.init is None:
|
||||
raise TypeError("to_dense must be set MetaTensor.init, init can't be None")
|
||||
|
||||
if shape is None:
|
||||
shape = self.shape
|
||||
|
||||
|
@ -428,15 +431,28 @@ class MetaTensor(MetaTensor_):
|
|||
msg = "Error shape={}".format(shape)
|
||||
logger.error(msg)
|
||||
raise ValueError(msg)
|
||||
from .seed import get_seed
|
||||
global_seed = get_seed()
|
||||
need_set_seed = ((slice_index is not None) and (global_seed is None))
|
||||
seed_saved = np.random.get_state()[1][0]
|
||||
if need_set_seed:
|
||||
np.random.seed(slice_index)
|
||||
self.init(arr)
|
||||
if need_set_seed:
|
||||
np.random.seed(seed_saved)
|
||||
class seed_context:
|
||||
'''set and restore seed'''
|
||||
def __init__(self, init):
|
||||
self.init = init
|
||||
from .seed import get_seed
|
||||
global_seed = get_seed()
|
||||
self._np_seed = np.random.get_state()[1][0]
|
||||
self.need_set_seed = ((slice_index is not None) and (global_seed is None))
|
||||
self.seed = self.init.seed
|
||||
|
||||
def __enter__(self):
|
||||
if self.need_set_seed:
|
||||
np.random.seed(slice_index)
|
||||
self.init.seed = slice_index
|
||||
|
||||
def __exit__(self, ptype, value, trace):
|
||||
if self.need_set_seed:
|
||||
np.random.seed(self._np_seed)
|
||||
self.init.seed = self.seed
|
||||
|
||||
with seed_context(self.init):
|
||||
self.init(arr)
|
||||
return Tensor(arr, dtype=self.dtype)
|
||||
|
||||
|
||||
|
|
|
@ -24,6 +24,7 @@ from mindspore import Tensor
|
|||
from mindspore.common import dtype as mstype
|
||||
from mindspore.nn import TrainOneStepCell, WithLossCell
|
||||
from mindspore.nn.optim import Adam
|
||||
from mindspore.common import set_seed
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common.initializer import TruncatedNormal
|
||||
from mindspore.parallel._ps_context import _is_role_pserver, _is_role_worker
|
||||
|
@ -95,12 +96,12 @@ def do_sparse_embedding(ps=False):
|
|||
|
||||
envs = os.environ
|
||||
if __name__ == "__main__":
|
||||
np.random.seed(0)
|
||||
set_seed(0)
|
||||
ps_loss = do_sparse_embedding(True)
|
||||
|
||||
if _is_role_worker():
|
||||
context.reset_ps_context()
|
||||
np.random.seed(0)
|
||||
set_seed(0)
|
||||
no_ps_loss = do_sparse_embedding()
|
||||
context.set_ps_context(enable_ps=True)
|
||||
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
|
||||
import numpy as np
|
||||
from numpy import allclose
|
||||
|
||||
from mindspore.common import set_seed
|
||||
import mindspore.common.initializer as init
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Parameter
|
||||
|
@ -40,10 +40,10 @@ class ParameterNet(nn.Cell):
|
|||
|
||||
|
||||
def test_using_same_seed_for_initializer():
|
||||
np.random.seed(0)
|
||||
set_seed(0)
|
||||
net1 = ParameterNet()
|
||||
net1.init_parameters_data()
|
||||
np.random.seed(0)
|
||||
set_seed(0)
|
||||
net2 = ParameterNet()
|
||||
net2.init_parameters_data()
|
||||
for key in net1.parameters_dict():
|
||||
|
|
Loading…
Reference in New Issue