forked from mindspore-Ecosystem/mindspore
!15907 add type convert during load checkpoint
From: @changzherui Reviewed-by: @kingxian,@zh_qh Signed-off-by: @kingxian
This commit is contained in:
commit
c13ed20dad
|
@ -82,23 +82,26 @@ def _special_process_par(par, new_par):
|
|||
return False
|
||||
|
||||
|
||||
def _update_param(param, new_param):
|
||||
def _update_param(param, new_param, strict_load):
|
||||
"""Updates param's data from new_param's data."""
|
||||
|
||||
if isinstance(param.data, Tensor) and isinstance(new_param.data, Tensor):
|
||||
if param.data.dtype != new_param.data.dtype:
|
||||
logger.error("Failed to combine the net and the parameters for param %s.", param.name)
|
||||
msg = ("Net parameters {} type({}) different from parameter_dict's({})"
|
||||
.format(param.name, param.data.dtype, new_param.data.dtype))
|
||||
raise RuntimeError(msg)
|
||||
|
||||
if param.data.shape != new_param.data.shape:
|
||||
if not _special_process_par(param, new_param):
|
||||
logger.error("Failed to combine the net and the parameters for param %s.", param.name)
|
||||
msg = ("Net parameters {} shape({}) different from parameter_dict's({})"
|
||||
.format(param.name, param.data.shape, new_param.data.shape))
|
||||
raise RuntimeError(msg)
|
||||
return
|
||||
|
||||
if param.data.dtype != new_param.data.dtype:
|
||||
if _type_convert(param, new_param, strict_load):
|
||||
new_tensor = Tensor(new_param.data.asnumpy(), param.data.dtype)
|
||||
param.set_data(new_tensor)
|
||||
return
|
||||
|
||||
logger.error("Failed to combine the net and the parameters for param %s.", param.name)
|
||||
msg = ("Net parameters {} type({}) different from parameter_dict's({})"
|
||||
.format(param.name, param.data.dtype, new_param.data.dtype))
|
||||
raise RuntimeError(msg)
|
||||
|
||||
param.set_data(new_param.data)
|
||||
return
|
||||
|
@ -121,11 +124,21 @@ def _update_param(param, new_param):
|
|||
param.set_data(type(param.data)(new_param.data))
|
||||
|
||||
|
||||
def _type_convert(param, new_param, strict_load):
|
||||
"""Whether to convert parameter's type during load checkpoint into network."""
|
||||
float_type = (mstype.float16, mstype.float32, mstype.float64)
|
||||
int_type = (mstype.int8, mstype.int16, mstype.int32, mstype.int64)
|
||||
if not strict_load and ({param.data.dtype, new_param.data.dtype}.issubset(float_type) or
|
||||
{param.data.dtype, new_param.data.dtype}.issubset(int_type)):
|
||||
logger.warning("ckpt_dict parameter: {}'s type is {}, convert to {} in the network."
|
||||
.format(new_param.name, new_param.data.dtype, param.data.dtype))
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM"):
|
||||
"""Execute the process of saving checkpoint into file."""
|
||||
|
||||
try:
|
||||
MAX_BLOCK_SIZE = 1024*1024*512
|
||||
with _ckpt_mutex:
|
||||
if os.path.exists(ckpt_file_name):
|
||||
os.remove(ckpt_file_name)
|
||||
|
@ -155,10 +168,10 @@ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM"):
|
|||
f.write(checkpoint_list.SerializeToString())
|
||||
else:
|
||||
plain_data += checkpoint_list.SerializeToString()
|
||||
while len(plain_data) >= MAX_BLOCK_SIZE:
|
||||
cipher_data += _encrypt(plain_data[0: MAX_BLOCK_SIZE], MAX_BLOCK_SIZE, enc_key,
|
||||
while len(plain_data) >= SLICE_SIZE * 1024:
|
||||
cipher_data += _encrypt(plain_data[0: SLICE_SIZE*1024], SLICE_SIZE*1024, enc_key,
|
||||
len(enc_key), enc_mode)
|
||||
plain_data = plain_data[MAX_BLOCK_SIZE:]
|
||||
plain_data = plain_data[SLICE_SIZE*1024:]
|
||||
|
||||
if enc_key is not None:
|
||||
if plain_data:
|
||||
|
@ -310,7 +323,8 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
|
|||
ckpt_file_name (str): Checkpoint file name.
|
||||
net (Cell): Cell network. Default: None
|
||||
strict_load (bool): Whether to strict load the parameter into net. If False, it will load parameter
|
||||
in the param_dict into net with the same suffix. Default: False
|
||||
in the param_dict into net with the same suffix and load
|
||||
parameter with different accuracy. Default: False.
|
||||
filter_prefix (Union[str, list[str], tuple[str]]): Parameters starting with the filter_prefix
|
||||
will not be loaded. Default: None.
|
||||
dec_key (Union[None, bytes]): Byte type key used for decryption. If the value is None, the decryption
|
||||
|
@ -469,12 +483,12 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
|
|||
logger.error("Failed to combine the net and the parameters.")
|
||||
msg = ("Argument parameter_dict element should be a Parameter, but got {}.".format(type(new_param)))
|
||||
raise TypeError(msg)
|
||||
_update_param(param, new_param)
|
||||
_update_param(param, new_param, strict_load)
|
||||
else:
|
||||
param_not_load.append(param.name)
|
||||
|
||||
if param_not_load and not strict_load:
|
||||
_load_dismatch_prefix_params(net, parameter_dict, param_not_load)
|
||||
_load_dismatch_prefix_params(net, parameter_dict, param_not_load, strict_load)
|
||||
|
||||
logger.debug("Params not matched(in net but not in parameter_dict):")
|
||||
for param_name in param_not_load:
|
||||
|
@ -486,7 +500,7 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
|
|||
return param_not_load
|
||||
|
||||
|
||||
def _load_dismatch_prefix_params(net, parameter_dict, param_not_load):
|
||||
def _load_dismatch_prefix_params(net, parameter_dict, param_not_load, strict_load):
|
||||
"""When some net parameter did not load, try to continue load."""
|
||||
prefix_name = ""
|
||||
longest_name = param_not_load[0]
|
||||
|
@ -507,7 +521,7 @@ def _load_dismatch_prefix_params(net, parameter_dict, param_not_load):
|
|||
new_param_name = prefix_name + param.name
|
||||
if param.name in param_not_load and new_param_name in parameter_dict:
|
||||
new_param = parameter_dict[new_param_name]
|
||||
_update_param(param, new_param)
|
||||
_update_param(param, new_param, strict_load)
|
||||
param_not_load.remove(param.name)
|
||||
|
||||
|
||||
|
|
|
@ -236,7 +236,7 @@ def test_load_param_into_net_param_type_and_shape_error():
|
|||
assert net.conv1.weight.data.asnumpy()[0][0][0][0] == 0
|
||||
|
||||
parameter_dict = {}
|
||||
one_param = Parameter(Tensor(np.ones(shape=(64, 3, 7, 7))), name="conv1.weight")
|
||||
one_param = Parameter(Tensor(np.ones(shape=(64, 3, 7, 7)), dtype=mstype.int32), name="conv1.weight")
|
||||
parameter_dict["conv1.weight"] = one_param
|
||||
with pytest.raises(RuntimeError):
|
||||
load_param_into_net(net, parameter_dict)
|
||||
|
|
Loading…
Reference in New Issue