diff --git a/mindspore/common/initializer.py b/mindspore/common/initializer.py index a50c88ec9a3..b83689f8825 100644 --- a/mindspore/common/initializer.py +++ b/mindspore/common/initializer.py @@ -41,7 +41,6 @@ class Initializer: self._kwargs = kwargs self.shape = None self.dtype = None - self._seed = None def _initialize(self, *kwargs): raise NotImplementedError('Must be overridden!') @@ -49,15 +48,6 @@ class Initializer: def __call__(self, arr): return self._initialize(arr) - @property - def seed(self): - return self._seed - - @seed.setter - def seed(self, seed_): - """set the random seed.""" - self._seed = seed_ - @property def shape(self): return self._shape @@ -74,8 +64,15 @@ class Initializer: def dtype(self, dtype): self._dtype = dtype - def to_tensor(self): - """Get the tensor format data of this Initializer.""" + def to_tensor(self, slice_index=None): + """ + Get the tensor format data of this Initializer. + + Args: + slice_index (int): Slice index of a parameter's slices. + Used when initializing a slice of a parameter; it guarantees that + devices using the same slice generate the same tensor. 
+ """ arr = None try: arr = np.ndarray(self.shape) @@ -83,10 +80,10 @@ class Initializer: msg = "Error shape={}".format(self.shape) logger.error(msg) raise ValueError(msg) - if self._seed is not None: - np.random.seed(self.seed) + + if slice_index is not None: + np.random.seed(slice_index) self.__call__(arr) - self._seed = None return Tensor(arr, dtype=self.dtype) def _register(*aliases): diff --git a/mindspore/common/parameter.py b/mindspore/common/parameter.py index a8e89d2784e..69affee2c32 100644 --- a/mindspore/common/parameter.py +++ b/mindspore/common/parameter.py @@ -22,7 +22,7 @@ from .initializer import initializer, Initializer from .tensor import Tensor, MetaTensor from .._checkparam import _check_str_by_regular from ..parallel._utils import _set_clone_info, _CloneInfo -from ..parallel._tensor import _get_seed +from ..parallel._tensor import _get_slice_index __all__ = ['Parameter', 'ParameterTuple'] @@ -250,9 +250,11 @@ class Parameter: raise ValueError("The length of layout must be 3! layout is {}." .format(layout)) self.init_mode.shape = layout[2] - self.init_mode.seed = int(_get_seed(layout[0], layout[1])) + slice_index = int(_get_slice_index(layout[0], layout[1])) + self.default_input = self.init_mode.to_tensor(slice_index) + else: + self.default_input = self.init_mode.to_tensor() - self.default_input = self.init_mode.to_tensor() self.init_mode = None if set_sliced: self.sliced = True diff --git a/mindspore/parallel/_tensor.py b/mindspore/parallel/_tensor.py index 073ad9809af..fca8b889201 100644 --- a/mindspore/parallel/_tensor.py +++ b/mindspore/parallel/_tensor.py @@ -168,21 +168,21 @@ def _chunk_tensor_by_strategy(np_tensor, strategy): raise ValueError("The length of np_tensor does not match the length of strategy!") return _chunk_tensor(np_tensor, strategy, len(strategy)) -def _get_seed(dev_mat, tensor_map): +def _get_slice_index(dev_mat, tensor_map): """ - Get the random seed for current slice. + Get the slice index for current slice. 
Args: dev_mat (list): The device matrix of devices. tensor_map (list): The split strategy of tensor. Returns: - Integer, the local random seed for this device. + Integer, the slice index for slice on this device. """ rank = get_rank() tensor_strategy = _get_tensor_strategy(dev_mat, tensor_map) - tensor_slice_seed = _get_tensor_slice_index(dev_mat, tensor_strategy, tensor_map, rank) - return tensor_slice_seed + tensor_slice_index = _get_tensor_slice_index(dev_mat, tensor_strategy, tensor_map, rank) + return tensor_slice_index def _load_tensor(tensor, dev_mat, tensor_map): """