modify_normal_seed

This commit is contained in:
lilei 2020-10-18 12:43:11 +08:00
parent c3d1f54c7f
commit 04075671cf
5 changed files with 49 additions and 22 deletions

View File

@ -19,7 +19,7 @@ import math
from functools import reduce
import numpy as np
from scipy.stats import truncnorm
from .seed import _get_graph_seed
from . import dtype as mstype
from .tensor import Tensor, MetaTensor
from .._c_expression import random_normal
@ -40,8 +40,19 @@ class Initializer:
"""
def __init__(self, **kwargs):
self._kwargs = kwargs
self.shape = None
self.dtype = None
self._seed = None
@property
def seed(self):
seed_ = self._seed if self._seed is not None else 1
_, seed = _get_graph_seed(seed_, "init")
return seed
@seed.setter
def seed(self, value):
if not isinstance(value, int):
raise TypeError("'value' must be int type.")
self._seed = value
def _initialize(self, *kwargs):
raise NotImplementedError('Must be overridden!')
@ -353,7 +364,7 @@ class Normal(Initializer):
self.sigma = sigma
def _initialize(self, arr):
seed = np.random.get_state()[1][0]
seed = self.seed
output_tensor = Tensor(np.zeros(arr.shape, dtype=np.float32))
random_normal(0, self.sigma, arr.shape, seed, output_tensor)
output_data = output_tensor.asnumpy()
@ -434,8 +445,7 @@ def initializer(init, shape=None, dtype=mstype.float32):
elif isinstance(init, numbers.Number):
init = Constant(init)
shape = shape if shape is not None else init.shape
dtype = init.dtype if init.dtype is not None else dtype
init_obj = MetaTensor(init, dtype, shape)
init_obj = MetaTensor(dtype, shape, init)
return init_obj
__all__ = [

View File

@ -30,7 +30,7 @@ def _reset_op_seed():
"""
Reset op seeds in the kernel's dictionary.
"""
for kernel_name, op_seed in _KERNEL_SEED.items():
for (kernel_name, op_seed) in _KERNEL_SEED:
_KERNEL_SEED[(kernel_name, op_seed)] = op_seed

View File

@ -404,7 +404,7 @@ class MetaTensor(MetaTensor_):
Returns:
Array, an array after being initialized.
"""
def __init__(self, init, dtype, shape):
def __init__(self, dtype, shape, init=None):
#check param
self.init = init
MetaTensor_.__init__(self, dtype, shape)
@ -419,6 +419,9 @@ class MetaTensor(MetaTensor_):
using the same slice can generate the same tensor.
shape (list[int]): Shape of the slice, it is used when initialize a slice of the parameter.
"""
if self.init is None:
raise TypeError("to_dense must be set MetaTensor.init, init can't be None")
if shape is None:
shape = self.shape
@ -428,15 +431,28 @@ class MetaTensor(MetaTensor_):
msg = "Error shape={}".format(shape)
logger.error(msg)
raise ValueError(msg)
from .seed import get_seed
global_seed = get_seed()
need_set_seed = ((slice_index is not None) and (global_seed is None))
seed_saved = np.random.get_state()[1][0]
if need_set_seed:
np.random.seed(slice_index)
self.init(arr)
if need_set_seed:
np.random.seed(seed_saved)
class seed_context:
'''set and restore seed'''
def __init__(self, init):
self.init = init
from .seed import get_seed
global_seed = get_seed()
self._np_seed = np.random.get_state()[1][0]
self.need_set_seed = ((slice_index is not None) and (global_seed is None))
self.seed = self.init.seed
def __enter__(self):
if self.need_set_seed:
np.random.seed(slice_index)
self.init.seed = slice_index
def __exit__(self, ptype, value, trace):
if self.need_set_seed:
np.random.seed(self._np_seed)
self.init.seed = self.seed
with seed_context(self.init):
self.init(arr)
return Tensor(arr, dtype=self.dtype)

View File

@ -24,6 +24,7 @@ from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn.optim import Adam
from mindspore.common import set_seed
from mindspore.ops import operations as P
from mindspore.common.initializer import TruncatedNormal
from mindspore.parallel._ps_context import _is_role_pserver, _is_role_worker
@ -95,12 +96,12 @@ def do_sparse_embedding(ps=False):
envs = os.environ
if __name__ == "__main__":
np.random.seed(0)
set_seed(0)
ps_loss = do_sparse_embedding(True)
if _is_role_worker():
context.reset_ps_context()
np.random.seed(0)
set_seed(0)
no_ps_loss = do_sparse_embedding()
context.set_ps_context(enable_ps=True)

View File

@ -14,7 +14,7 @@
import numpy as np
from numpy import allclose
from mindspore.common import set_seed
import mindspore.common.initializer as init
import mindspore.nn as nn
from mindspore import Parameter
@ -40,10 +40,10 @@ class ParameterNet(nn.Cell):
def test_using_same_seed_for_initializer():
np.random.seed(0)
set_seed(0)
net1 = ParameterNet()
net1.init_parameters_data()
np.random.seed(0)
set_seed(0)
net2 = ParameterNet()
net2.init_parameters_data()
for key in net1.parameters_dict():