diff --git a/mindspore/common/initializer.py b/mindspore/common/initializer.py index b83689f8825..83586272ee2 100644 --- a/mindspore/common/initializer.py +++ b/mindspore/common/initializer.py @@ -64,7 +64,7 @@ class Initializer: def dtype(self, dtype): self._dtype = dtype - def to_tensor(self, slice_index=None): + def to_tensor(self, slice_index=None, shape=None): """ Get the tensor format data of this Initializer. @@ -72,12 +72,16 @@ class Initializer: slice_index (int): Slice index of a parameter's slices. Used when initialize a slice of a parameter, it guarantee that devices use the same slice can generate the same tensor. + shape (list[int]): Shape of the slice, used when initializing a slice of the parameter. """ arr = None + if shape is None: + shape = self.shape + try: - arr = np.ndarray(self.shape) + arr = np.ndarray(shape) except ValueError: - msg = "Error shape={}".format(self.shape) + msg = "Error shape={}".format(shape) logger.error(msg) raise ValueError(msg) diff --git a/mindspore/common/parameter.py b/mindspore/common/parameter.py index 69affee2c32..4e8cf288885 100644 --- a/mindspore/common/parameter.py +++ b/mindspore/common/parameter.py @@ -249,9 +249,8 @@ class Parameter: if len(layout) != 3: raise ValueError("The length of layout must be 3! layout is {}." .format(layout)) - self.init_mode.shape = layout[2] slice_index = int(_get_slice_index(layout[0], layout[1])) - self.default_input = self.init_mode.to_tensor(slice_index) + self.default_input = self.init_mode.to_tensor(slice_index, layout[2]) else: self.default_input = self.init_mode.to_tensor()