!30322 modify set_data for parameter

Merge pull request !30322 from lilei/modify_initializer_for_master
This commit is contained in:
i-robot 2022-02-24 08:14:30 +00:00 committed by Gitee
commit 14421b139b
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 23 additions and 16 deletions

View File

@ -516,6 +516,23 @@ class Parameter(Tensor_):
raise TypeError("`stage` must be a positive number of int type")
self._pipeline_stage_list.append(stage)
def _raise_type_error(self, incoming):
raise TypeError(f"Incoming Parameter dtype can not be converted to current dtype implicitly. "
f"Current dtype is {self.dtype}, and incoming is {incoming}. "
f"Use .set_dtype(xxx) to change the dtype.")
@staticmethod
def _set_data_check_input_valid(current_shape, data_shape, current_tensor_is_init,
incoming_tensor_is_init, slice_shape=False):
if incoming_tensor_is_init and not current_tensor_is_init:
raise TypeError("The original tensor data is initialized, but the argument 'data' is not initialized."
"Please initialize 'data' before call this method.")
if tuple(current_shape) != tuple(data_shape):
# If Slice create Parameter shape can be change.
if not slice_shape:
raise ValueError(f"Can not change the shape of Parameter which has been initialized."
f" Current shape is {current_shape}, and incoming is {data_shape}.")
def set_data(self, data, slice_shape=False):
"""
Set Parameter's data.
@ -528,35 +545,25 @@ class Parameter(Tensor_):
Returns:
Parameter, the parameter after set data.
"""
def raise_type_error(incoming):
raise TypeError(f"Incoming Parameter dtype can not be converted to current dtype implicitly. "
f"Current dtype is {self.dtype}, and incoming is {incoming}. "
f"Use .set_dtype(xxx) to change the dtype.")
if not isinstance(data, (Tensor, int, float)):
raise TypeError(f"Parameter data must be [`Tensor`, `int`, `float`] or a kind of `Tensor` "
f"(like `Tensor`). But with type {type(data)}.")
if isinstance(data, (int, float)):
if self.dtype in mstype.int_type and isinstance(data, float):
raise_type_error(mstype.float_)
self._raise_type_error(mstype.float_)
data = Tensor(data, self.dtype)
# both not init.
incoming_tensor_is_init = isinstance(data, Tensor) and not data.has_init
current_tensor_is_init = isinstance(self, Tensor) and not self.has_init
if incoming_tensor_is_init and not current_tensor_is_init:
raise TypeError("The original tensor data is initialized, but the argument 'data' is not initialized."
"Please initialize 'data' before call this method.")
if tuple(self.shape) != tuple(data.shape):
# If Slice create Parameter shape can be change.
if not slice_shape:
raise ValueError(f"Can not change the shape of Parameter which has been initialized."
f" Current shape is {self.shape}, and incoming is {data.shape}.")
Parameter._set_data_check_input_valid(self.shape, data.shape, current_tensor_is_init, incoming_tensor_is_init,
slice_shape)
if self.dtype != data.dtype:
if mstype.implicit_conversion_seq[self.dtype] < mstype.implicit_conversion_seq[data.dtype]:
raise_type_error(data.dtype)
self._raise_type_error(data.dtype)
else:
from mindspore.ops import functional as F
if isinstance(data, Tensor) and data.init is not None:
data.init_data()
data = F.cast(data, self.dtype)
if isinstance(data, Tensor) and data.has_init:
# The parameter has been initialized, directly update by the data