don't change shape of Initializer when init slice of a Parameter

This commit is contained in:
Yi Huaijie 2020-06-20 10:09:25 +08:00
parent 634bfd3508
commit 27b5bc6d95
2 changed files with 8 additions and 5 deletions

View File

@ -64,7 +64,7 @@ class Initializer:
def dtype(self, dtype): def dtype(self, dtype):
self._dtype = dtype self._dtype = dtype
def to_tensor(self, slice_index=None): def to_tensor(self, slice_index=None, shape=None):
""" """
Get the tensor format data of this Initializer. Get the tensor format data of this Initializer.
@ -72,12 +72,16 @@ class Initializer:
slice_index (int): Slice index of a parameter's slices. slice_index (int): Slice index of a parameter's slices.
Used when initialize a slice of a parameter, it guarantee that Used when initialize a slice of a parameter, it guarantee that
devices use the same slice can generate the same tensor. devices use the same slice can generate the same tensor.
shape (list[int]): Shape of the slice, used when initialize a slice of the parameter.
""" """
arr = None arr = None
if shape is None:
shape = self.shape
try: try:
arr = np.ndarray(self.shape) arr = np.ndarray(shape)
except ValueError: except ValueError:
msg = "Error shape={}".format(self.shape) msg = "Error shape={}".format(shape)
logger.error(msg) logger.error(msg)
raise ValueError(msg) raise ValueError(msg)

View File

@ -249,9 +249,8 @@ class Parameter:
if len(layout) != 3: if len(layout) != 3:
raise ValueError("The length of layout must be 3! layout is {}." raise ValueError("The length of layout must be 3! layout is {}."
.format(layout)) .format(layout))
self.init_mode.shape = layout[2]
slice_index = int(_get_slice_index(layout[0], layout[1])) slice_index = int(_get_slice_index(layout[0], layout[1]))
self.default_input = self.init_mode.to_tensor(slice_index) self.default_input = self.init_mode.to_tensor(slice_index, layout[2])
else: else:
self.default_input = self.init_mode.to_tensor() self.default_input = self.init_mode.to_tensor()