forked from mindspore-Ecosystem/mindspore
don't change shape of Initializer when init slice of a Parameter
This commit is contained in:
parent
634bfd3508
commit
27b5bc6d95
|
@ -64,7 +64,7 @@ class Initializer:
|
||||||
def dtype(self, dtype):
|
def dtype(self, dtype):
|
||||||
self._dtype = dtype
|
self._dtype = dtype
|
||||||
|
|
||||||
def to_tensor(self, slice_index=None):
|
def to_tensor(self, slice_index=None, shape=None):
|
||||||
"""
|
"""
|
||||||
Get the tensor format data of this Initializer.
|
Get the tensor format data of this Initializer.
|
||||||
|
|
||||||
|
@ -72,12 +72,16 @@ class Initializer:
|
||||||
slice_index (int): Slice index of a parameter's slices.
|
slice_index (int): Slice index of a parameter's slices.
|
||||||
Used when initialize a slice of a parameter, it guarantee that
|
Used when initialize a slice of a parameter, it guarantee that
|
||||||
devices use the same slice can generate the same tensor.
|
devices use the same slice can generate the same tensor.
|
||||||
|
shape (list[int]): Shape of the slice, used when initialize a slice of the parameter.
|
||||||
"""
|
"""
|
||||||
arr = None
|
arr = None
|
||||||
|
if shape is None:
|
||||||
|
shape = self.shape
|
||||||
|
|
||||||
try:
|
try:
|
||||||
arr = np.ndarray(self.shape)
|
arr = np.ndarray(shape)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
msg = "Error shape={}".format(self.shape)
|
msg = "Error shape={}".format(shape)
|
||||||
logger.error(msg)
|
logger.error(msg)
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
|
|
@ -249,9 +249,8 @@ class Parameter:
|
||||||
if len(layout) != 3:
|
if len(layout) != 3:
|
||||||
raise ValueError("The length of layout must be 3! layout is {}."
|
raise ValueError("The length of layout must be 3! layout is {}."
|
||||||
.format(layout))
|
.format(layout))
|
||||||
self.init_mode.shape = layout[2]
|
|
||||||
slice_index = int(_get_slice_index(layout[0], layout[1]))
|
slice_index = int(_get_slice_index(layout[0], layout[1]))
|
||||||
self.default_input = self.init_mode.to_tensor(slice_index)
|
self.default_input = self.init_mode.to_tensor(slice_index, layout[2])
|
||||||
else:
|
else:
|
||||||
self.default_input = self.init_mode.to_tensor()
|
self.default_input = self.init_mode.to_tensor()
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue