modify_normal_seed

This commit is contained in:
lilei 2020-10-18 12:43:11 +08:00
parent c3d1f54c7f
commit 04075671cf
5 changed files with 49 additions and 22 deletions

View File

@ -19,7 +19,7 @@ import math
from functools import reduce from functools import reduce
import numpy as np import numpy as np
from scipy.stats import truncnorm from scipy.stats import truncnorm
from .seed import _get_graph_seed
from . import dtype as mstype from . import dtype as mstype
from .tensor import Tensor, MetaTensor from .tensor import Tensor, MetaTensor
from .._c_expression import random_normal from .._c_expression import random_normal
@ -40,8 +40,19 @@ class Initializer:
""" """
def __init__(self, **kwargs): def __init__(self, **kwargs):
self._kwargs = kwargs self._kwargs = kwargs
self.shape = None self._seed = None
self.dtype = None
@property
def seed(self):
seed_ = self._seed if self._seed is not None else 1
_, seed = _get_graph_seed(seed_, "init")
return seed
@seed.setter
def seed(self, value):
if not isinstance(value, int):
raise TypeError("'value' must be int type.")
self._seed = value
def _initialize(self, *kwargs): def _initialize(self, *kwargs):
raise NotImplementedError('Must be overridden!') raise NotImplementedError('Must be overridden!')
@ -353,7 +364,7 @@ class Normal(Initializer):
self.sigma = sigma self.sigma = sigma
def _initialize(self, arr): def _initialize(self, arr):
seed = np.random.get_state()[1][0] seed = self.seed
output_tensor = Tensor(np.zeros(arr.shape, dtype=np.float32)) output_tensor = Tensor(np.zeros(arr.shape, dtype=np.float32))
random_normal(0, self.sigma, arr.shape, seed, output_tensor) random_normal(0, self.sigma, arr.shape, seed, output_tensor)
output_data = output_tensor.asnumpy() output_data = output_tensor.asnumpy()
@ -434,8 +445,7 @@ def initializer(init, shape=None, dtype=mstype.float32):
elif isinstance(init, numbers.Number): elif isinstance(init, numbers.Number):
init = Constant(init) init = Constant(init)
shape = shape if shape is not None else init.shape shape = shape if shape is not None else init.shape
dtype = init.dtype if init.dtype is not None else dtype init_obj = MetaTensor(dtype, shape, init)
init_obj = MetaTensor(init, dtype, shape)
return init_obj return init_obj
__all__ = [ __all__ = [

View File

@ -30,7 +30,7 @@ def _reset_op_seed():
""" """
Reset op seeds in the kernel's dictionary. Reset op seeds in the kernel's dictionary.
""" """
for kernel_name, op_seed in _KERNEL_SEED.items(): for (kernel_name, op_seed) in _KERNEL_SEED:
_KERNEL_SEED[(kernel_name, op_seed)] = op_seed _KERNEL_SEED[(kernel_name, op_seed)] = op_seed

View File

@ -404,7 +404,7 @@ class MetaTensor(MetaTensor_):
Returns: Returns:
Array, an array after being initialized. Array, an array after being initialized.
""" """
def __init__(self, init, dtype, shape): def __init__(self, dtype, shape, init=None):
#check param #check param
self.init = init self.init = init
MetaTensor_.__init__(self, dtype, shape) MetaTensor_.__init__(self, dtype, shape)
@ -419,6 +419,9 @@ class MetaTensor(MetaTensor_):
using the same slice can generate the same tensor. using the same slice can generate the same tensor.
shape (list[int]): Shape of the slice, it is used when initialize a slice of the parameter. shape (list[int]): Shape of the slice, it is used when initialize a slice of the parameter.
""" """
if self.init is None:
raise TypeError("to_dense must be set MetaTensor.init, init can't be None")
if shape is None: if shape is None:
shape = self.shape shape = self.shape
@ -428,15 +431,28 @@ class MetaTensor(MetaTensor_):
msg = "Error shape={}".format(shape) msg = "Error shape={}".format(shape)
logger.error(msg) logger.error(msg)
raise ValueError(msg) raise ValueError(msg)
class seed_context:
'''set and restore seed'''
def __init__(self, init):
self.init = init
from .seed import get_seed from .seed import get_seed
global_seed = get_seed() global_seed = get_seed()
need_set_seed = ((slice_index is not None) and (global_seed is None)) self._np_seed = np.random.get_state()[1][0]
seed_saved = np.random.get_state()[1][0] self.need_set_seed = ((slice_index is not None) and (global_seed is None))
if need_set_seed: self.seed = self.init.seed
def __enter__(self):
if self.need_set_seed:
np.random.seed(slice_index) np.random.seed(slice_index)
self.init.seed = slice_index
def __exit__(self, ptype, value, trace):
if self.need_set_seed:
np.random.seed(self._np_seed)
self.init.seed = self.seed
with seed_context(self.init):
self.init(arr) self.init(arr)
if need_set_seed:
np.random.seed(seed_saved)
return Tensor(arr, dtype=self.dtype) return Tensor(arr, dtype=self.dtype)

View File

@ -24,6 +24,7 @@ from mindspore import Tensor
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from mindspore.nn import TrainOneStepCell, WithLossCell from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn.optim import Adam from mindspore.nn.optim import Adam
from mindspore.common import set_seed
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.common.initializer import TruncatedNormal from mindspore.common.initializer import TruncatedNormal
from mindspore.parallel._ps_context import _is_role_pserver, _is_role_worker from mindspore.parallel._ps_context import _is_role_pserver, _is_role_worker
@ -95,12 +96,12 @@ def do_sparse_embedding(ps=False):
envs = os.environ envs = os.environ
if __name__ == "__main__": if __name__ == "__main__":
np.random.seed(0) set_seed(0)
ps_loss = do_sparse_embedding(True) ps_loss = do_sparse_embedding(True)
if _is_role_worker(): if _is_role_worker():
context.reset_ps_context() context.reset_ps_context()
np.random.seed(0) set_seed(0)
no_ps_loss = do_sparse_embedding() no_ps_loss = do_sparse_embedding()
context.set_ps_context(enable_ps=True) context.set_ps_context(enable_ps=True)

View File

@ -14,7 +14,7 @@
import numpy as np import numpy as np
from numpy import allclose from numpy import allclose
from mindspore.common import set_seed
import mindspore.common.initializer as init import mindspore.common.initializer as init
import mindspore.nn as nn import mindspore.nn as nn
from mindspore import Parameter from mindspore import Parameter
@ -40,10 +40,10 @@ class ParameterNet(nn.Cell):
def test_using_same_seed_for_initializer(): def test_using_same_seed_for_initializer():
np.random.seed(0) set_seed(0)
net1 = ParameterNet() net1 = ParameterNet()
net1.init_parameters_data() net1.init_parameters_data()
np.random.seed(0) set_seed(0)
net2 = ParameterNet() net2 = ParameterNet()
net2.init_parameters_data() net2.init_parameters_data()
for key in net1.parameters_dict(): for key in net1.parameters_dict():