forked from mindspore-Ecosystem/mindspore
!2392 don't change shape of Initializer when init slice of a Parameter
Merge pull request !2392 from yihuaijie/master
This commit is contained in:
commit
f3f95b255b
|
@ -64,7 +64,7 @@ class Initializer:
|
|||
def dtype(self, dtype):
|
||||
self._dtype = dtype
|
||||
|
||||
def to_tensor(self, slice_index=None):
|
||||
def to_tensor(self, slice_index=None, shape=None):
|
||||
"""
|
||||
Get the tensor format data of this Initializer.
|
||||
|
||||
|
@ -72,12 +72,16 @@ class Initializer:
|
|||
slice_index (int): Slice index of a parameter's slices.
|
||||
Used when initialize a slice of a parameter, it guarantee that
|
||||
devices use the same slice can generate the same tensor.
|
||||
shape (list[int]): Shape of the slice, used when initialize a slice of the parameter.
|
||||
"""
|
||||
arr = None
|
||||
if shape is None:
|
||||
shape = self.shape
|
||||
|
||||
try:
|
||||
arr = np.ndarray(self.shape)
|
||||
arr = np.ndarray(shape)
|
||||
except ValueError:
|
||||
msg = "Error shape={}".format(self.shape)
|
||||
msg = "Error shape={}".format(shape)
|
||||
logger.error(msg)
|
||||
raise ValueError(msg)
|
||||
|
||||
|
|
|
@ -249,9 +249,8 @@ class Parameter:
|
|||
if len(layout) != 3:
|
||||
raise ValueError("The length of layout must be 3! layout is {}."
|
||||
.format(layout))
|
||||
self.init_mode.shape = layout[2]
|
||||
slice_index = int(_get_slice_index(layout[0], layout[1]))
|
||||
self.default_input = self.init_mode.to_tensor(slice_index)
|
||||
self.default_input = self.init_mode.to_tensor(slice_index, layout[2])
|
||||
else:
|
||||
self.default_input = self.init_mode.to_tensor()
|
||||
|
||||
|
|
Loading…
Reference in New Issue