forked from OSSInnovation/mindspore
delete attribute seed of Initializer
This commit is contained in:
parent
60de9089ba
commit
eae69a386a
|
@ -41,7 +41,6 @@ class Initializer:
|
||||||
self._kwargs = kwargs
|
self._kwargs = kwargs
|
||||||
self.shape = None
|
self.shape = None
|
||||||
self.dtype = None
|
self.dtype = None
|
||||||
self._seed = None
|
|
||||||
|
|
||||||
def _initialize(self, *kwargs):
|
def _initialize(self, *kwargs):
|
||||||
raise NotImplementedError('Must be overridden!')
|
raise NotImplementedError('Must be overridden!')
|
||||||
|
@ -49,15 +48,6 @@ class Initializer:
|
||||||
def __call__(self, arr):
|
def __call__(self, arr):
|
||||||
return self._initialize(arr)
|
return self._initialize(arr)
|
||||||
|
|
||||||
@property
|
|
||||||
def seed(self):
|
|
||||||
return self._seed
|
|
||||||
|
|
||||||
@seed.setter
|
|
||||||
def seed(self, seed_):
|
|
||||||
"""set the random seed."""
|
|
||||||
self._seed = seed_
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def shape(self):
|
def shape(self):
|
||||||
return self._shape
|
return self._shape
|
||||||
|
@ -74,8 +64,15 @@ class Initializer:
|
||||||
def dtype(self, dtype):
|
def dtype(self, dtype):
|
||||||
self._dtype = dtype
|
self._dtype = dtype
|
||||||
|
|
||||||
def to_tensor(self):
|
def to_tensor(self, slice_index=None):
|
||||||
"""Get the tensor format data of this Initializer."""
|
"""
|
||||||
|
Get the tensor format data of this Initializer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
slice_index (int): Slice index of a parameter's slices.
|
||||||
|
Used when initialize a slice of a parameter, it guarantee that
|
||||||
|
devices use the same slice can generate the same tensor.
|
||||||
|
"""
|
||||||
arr = None
|
arr = None
|
||||||
try:
|
try:
|
||||||
arr = np.ndarray(self.shape)
|
arr = np.ndarray(self.shape)
|
||||||
|
@ -83,10 +80,10 @@ class Initializer:
|
||||||
msg = "Error shape={}".format(self.shape)
|
msg = "Error shape={}".format(self.shape)
|
||||||
logger.error(msg)
|
logger.error(msg)
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
if self._seed is not None:
|
|
||||||
np.random.seed(self.seed)
|
if slice_index is not None:
|
||||||
|
np.random.seed(slice_index)
|
||||||
self.__call__(arr)
|
self.__call__(arr)
|
||||||
self._seed = None
|
|
||||||
return Tensor(arr, dtype=self.dtype)
|
return Tensor(arr, dtype=self.dtype)
|
||||||
|
|
||||||
def _register(*aliases):
|
def _register(*aliases):
|
||||||
|
|
|
@ -22,7 +22,7 @@ from .initializer import initializer, Initializer
|
||||||
from .tensor import Tensor, MetaTensor
|
from .tensor import Tensor, MetaTensor
|
||||||
from .._checkparam import _check_str_by_regular
|
from .._checkparam import _check_str_by_regular
|
||||||
from ..parallel._utils import _set_clone_info, _CloneInfo
|
from ..parallel._utils import _set_clone_info, _CloneInfo
|
||||||
from ..parallel._tensor import _get_seed
|
from ..parallel._tensor import _get_slice_index
|
||||||
|
|
||||||
__all__ = ['Parameter', 'ParameterTuple']
|
__all__ = ['Parameter', 'ParameterTuple']
|
||||||
|
|
||||||
|
@ -250,9 +250,11 @@ class Parameter:
|
||||||
raise ValueError("The length of layout must be 3! layout is {}."
|
raise ValueError("The length of layout must be 3! layout is {}."
|
||||||
.format(layout))
|
.format(layout))
|
||||||
self.init_mode.shape = layout[2]
|
self.init_mode.shape = layout[2]
|
||||||
self.init_mode.seed = int(_get_seed(layout[0], layout[1]))
|
slice_index = int(_get_slice_index(layout[0], layout[1]))
|
||||||
|
self.default_input = self.init_mode.to_tensor(slice_index)
|
||||||
|
else:
|
||||||
self.default_input = self.init_mode.to_tensor()
|
self.default_input = self.init_mode.to_tensor()
|
||||||
|
|
||||||
self.init_mode = None
|
self.init_mode = None
|
||||||
if set_sliced:
|
if set_sliced:
|
||||||
self.sliced = True
|
self.sliced = True
|
||||||
|
|
|
@ -168,21 +168,21 @@ def _chunk_tensor_by_strategy(np_tensor, strategy):
|
||||||
raise ValueError("The length of np_tensor does not match the length of strategy!")
|
raise ValueError("The length of np_tensor does not match the length of strategy!")
|
||||||
return _chunk_tensor(np_tensor, strategy, len(strategy))
|
return _chunk_tensor(np_tensor, strategy, len(strategy))
|
||||||
|
|
||||||
def _get_seed(dev_mat, tensor_map):
|
def _get_slice_index(dev_mat, tensor_map):
|
||||||
"""
|
"""
|
||||||
Get the random seed for current slice.
|
Get the slice index for current slice.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dev_mat (list): The device matrix of devices.
|
dev_mat (list): The device matrix of devices.
|
||||||
tensor_map (list): The split strategy of tensor.
|
tensor_map (list): The split strategy of tensor.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Integer, the local random seed for this device.
|
Integer, the slice index for slice on this device.
|
||||||
"""
|
"""
|
||||||
rank = get_rank()
|
rank = get_rank()
|
||||||
tensor_strategy = _get_tensor_strategy(dev_mat, tensor_map)
|
tensor_strategy = _get_tensor_strategy(dev_mat, tensor_map)
|
||||||
tensor_slice_seed = _get_tensor_slice_index(dev_mat, tensor_strategy, tensor_map, rank)
|
tensor_slice_index = _get_tensor_slice_index(dev_mat, tensor_strategy, tensor_map, rank)
|
||||||
return tensor_slice_seed
|
return tensor_slice_index
|
||||||
|
|
||||||
def _load_tensor(tensor, dev_mat, tensor_map):
|
def _load_tensor(tensor, dev_mat, tensor_map):
|
||||||
"""
|
"""
|
||||||
|
|
Loading…
Reference in New Issue