forked from mindspore-Ecosystem/mindspore
!30322 modify set_data for parameter
Merge pull request !30322 from lilei/modify_initializer_for_master
commit 14421b139b
@@ -516,6 +516,23 @@ class Parameter(Tensor_):
             raise TypeError("`stage` must be a positive number of int type")
         self._pipeline_stage_list.append(stage)
 
+    def _raise_type_error(self, incoming):
+        raise TypeError(f"Incoming Parameter dtype can not be converted to current dtype implicitly. "
+                        f"Current dtype is {self.dtype}, and incoming is {incoming}. "
+                        f"Use .set_dtype(xxx) to change the dtype.")
+
+    @staticmethod
+    def _set_data_check_input_valid(current_shape, data_shape, current_tensor_is_init,
+                                    incoming_tensor_is_init, slice_shape=False):
+        if incoming_tensor_is_init and not current_tensor_is_init:
+            raise TypeError("The original tensor data is initialized, but the argument 'data' is not initialized."
+                            "Please initialize 'data' before call this method.")
+        if tuple(current_shape) != tuple(data_shape):
+            # If Slice create Parameter shape can be change.
+            if not slice_shape:
+                raise ValueError(f"Can not change the shape of Parameter which has been initialized."
+                                 f" Current shape is {current_shape}, and incoming is {data_shape}.")
+
     def set_data(self, data, slice_shape=False):
         """
         Set Parameter's data.
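For context, the behaviour enforced by the new helpers above can be seen from the caller's side. The following is a minimal sketch, not part of the commit, assuming a standard MindSpore install with NumPy; the variable names are illustrative only.

import numpy as np
from mindspore import Parameter, Tensor

# Create a float32 Parameter with shape (2, 3).
param = Parameter(Tensor(np.zeros((2, 3), np.float32)), name="w")

# Same shape and dtype: accepted, the data is replaced.
param.set_data(Tensor(np.ones((2, 3), np.float32)))

# Different shape with slice_shape left at its default (False):
# the shape check raises ValueError, as coded in the hunk above.
try:
    param.set_data(Tensor(np.ones((4, 3), np.float32)))
except ValueError as err:
    print(err)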
@@ -528,35 +545,25 @@ class Parameter(Tensor_):
         Returns:
             Parameter, the parameter after set data.
         """
-        def raise_type_error(incoming):
-            raise TypeError(f"Incoming Parameter dtype can not be converted to current dtype implicitly. "
-                            f"Current dtype is {self.dtype}, and incoming is {incoming}. "
-                            f"Use .set_dtype(xxx) to change the dtype.")
         if not isinstance(data, (Tensor, int, float)):
             raise TypeError(f"Parameter data must be [`Tensor`, `int`, `float`] or a kind of `Tensor` "
                             f"(like `Tensor`). But with type {type(data)}.")
         if isinstance(data, (int, float)):
             if self.dtype in mstype.int_type and isinstance(data, float):
-                raise_type_error(mstype.float_)
+                self._raise_type_error(mstype.float_)
             data = Tensor(data, self.dtype)
         # both not init.
         incoming_tensor_is_init = isinstance(data, Tensor) and not data.has_init
         current_tensor_is_init = isinstance(self, Tensor) and not self.has_init
-        if incoming_tensor_is_init and not current_tensor_is_init:
-            raise TypeError("The original tensor data is initialized, but the argument 'data' is not initialized."
-                            "Please initialize 'data' before call this method.")
-        if tuple(self.shape) != tuple(data.shape):
-            # If Slice create Parameter shape can be change.
-            if not slice_shape:
-                raise ValueError(f"Can not change the shape of Parameter which has been initialized."
-                                 f" Current shape is {self.shape}, and incoming is {data.shape}.")
+        Parameter._set_data_check_input_valid(self.shape, data.shape, current_tensor_is_init, incoming_tensor_is_init,
+                                              slice_shape)
         if self.dtype != data.dtype:
             if mstype.implicit_conversion_seq[self.dtype] < mstype.implicit_conversion_seq[data.dtype]:
-                raise_type_error(data.dtype)
+                self._raise_type_error(data.dtype)
             else:
                 from mindspore.ops import functional as F
                 if isinstance(data, Tensor) and data.init is not None:
                     data.init_data()
                 data = F.cast(data, self.dtype)
         if isinstance(data, Tensor) and data.has_init:
             # The parameter has been initialized, directly update by the data
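As a usage note, the dtype handling above cuts both ways: incoming data with a lower-priority dtype is cast to the Parameter's dtype via F.cast, while higher-priority data goes through _raise_type_error. A minimal sketch of both cases, not part of the commit, assuming a standard MindSpore install with NumPy; the variable names are illustrative only.

import numpy as np
from mindspore import Parameter, Tensor

# A float32 Parameter accepts int32 data: it is implicitly cast up to float32.
fp_param = Parameter(Tensor(np.zeros((2, 2), np.float32)), name="fp")
fp_param.set_data(Tensor(np.ones((2, 2), np.int32)))
print(fp_param.dtype)  # Float32

# An int32 Parameter rejects a Python float: set_data calls _raise_type_error,
# which points the user to .set_dtype(...) instead of an implicit down-cast.
int_param = Parameter(Tensor(np.zeros((2, 2), np.int32)), name="i")
try:
    int_param.set_data(1.5)
except TypeError as err:
    print(err)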