!15907 add type convert during load checkpoint

From: @changzherui
Reviewed-by: @kingxian,@zh_qh
Signed-off-by: @kingxian
This commit is contained in:
mindspore-ci-bot 2021-04-30 15:52:41 +08:00 committed by Gitee
commit c13ed20dad
2 changed files with 34 additions and 20 deletions

View File

@ -82,23 +82,26 @@ def _special_process_par(par, new_par):
return False
def _update_param(param, new_param):
def _update_param(param, new_param, strict_load):
"""Updates param's data from new_param's data."""
if isinstance(param.data, Tensor) and isinstance(new_param.data, Tensor):
if param.data.dtype != new_param.data.dtype:
logger.error("Failed to combine the net and the parameters for param %s.", param.name)
msg = ("Net parameters {} type({}) different from parameter_dict's({})"
.format(param.name, param.data.dtype, new_param.data.dtype))
raise RuntimeError(msg)
if param.data.shape != new_param.data.shape:
if not _special_process_par(param, new_param):
logger.error("Failed to combine the net and the parameters for param %s.", param.name)
msg = ("Net parameters {} shape({}) different from parameter_dict's({})"
.format(param.name, param.data.shape, new_param.data.shape))
raise RuntimeError(msg)
return
if param.data.dtype != new_param.data.dtype:
if _type_convert(param, new_param, strict_load):
new_tensor = Tensor(new_param.data.asnumpy(), param.data.dtype)
param.set_data(new_tensor)
return
logger.error("Failed to combine the net and the parameters for param %s.", param.name)
msg = ("Net parameters {} type({}) different from parameter_dict's({})"
.format(param.name, param.data.dtype, new_param.data.dtype))
raise RuntimeError(msg)
param.set_data(new_param.data)
return
@ -121,11 +124,21 @@ def _update_param(param, new_param):
param.set_data(type(param.data)(new_param.data))
def _type_convert(param, new_param, strict_load):
"""Whether to convert parameter's type during load checkpoint into network."""
float_type = (mstype.float16, mstype.float32, mstype.float64)
int_type = (mstype.int8, mstype.int16, mstype.int32, mstype.int64)
if not strict_load and ({param.data.dtype, new_param.data.dtype}.issubset(float_type) or
{param.data.dtype, new_param.data.dtype}.issubset(int_type)):
logger.warning("ckpt_dict parameter: {}'s type is {}, convert to {} in the network."
.format(new_param.name, new_param.data.dtype, param.data.dtype))
return True
return False
def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM"):
"""Execute the process of saving checkpoint into file."""
try:
MAX_BLOCK_SIZE = 1024*1024*512
with _ckpt_mutex:
if os.path.exists(ckpt_file_name):
os.remove(ckpt_file_name)
@ -155,10 +168,10 @@ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM"):
f.write(checkpoint_list.SerializeToString())
else:
plain_data += checkpoint_list.SerializeToString()
while len(plain_data) >= MAX_BLOCK_SIZE:
cipher_data += _encrypt(plain_data[0: MAX_BLOCK_SIZE], MAX_BLOCK_SIZE, enc_key,
while len(plain_data) >= SLICE_SIZE * 1024:
cipher_data += _encrypt(plain_data[0: SLICE_SIZE*1024], SLICE_SIZE*1024, enc_key,
len(enc_key), enc_mode)
plain_data = plain_data[MAX_BLOCK_SIZE:]
plain_data = plain_data[SLICE_SIZE*1024:]
if enc_key is not None:
if plain_data:
@ -310,7 +323,8 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
ckpt_file_name (str): Checkpoint file name.
net (Cell): Cell network. Default: None
strict_load (bool): Whether to strict load the parameter into net. If False, it will load parameter
in the param_dict into net with the same suffix. Default: False
in the param_dict into net with the same suffix and load
parameter with different accuracy. Default: False.
filter_prefix (Union[str, list[str], tuple[str]]): Parameters starting with the filter_prefix
will not be loaded. Default: None.
dec_key (Union[None, bytes]): Byte type key used for decryption. If the value is None, the decryption
@ -469,12 +483,12 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
logger.error("Failed to combine the net and the parameters.")
msg = ("Argument parameter_dict element should be a Parameter, but got {}.".format(type(new_param)))
raise TypeError(msg)
_update_param(param, new_param)
_update_param(param, new_param, strict_load)
else:
param_not_load.append(param.name)
if param_not_load and not strict_load:
_load_dismatch_prefix_params(net, parameter_dict, param_not_load)
_load_dismatch_prefix_params(net, parameter_dict, param_not_load, strict_load)
logger.debug("Params not matched(in net but not in parameter_dict):")
for param_name in param_not_load:
@ -486,7 +500,7 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
return param_not_load
def _load_dismatch_prefix_params(net, parameter_dict, param_not_load):
def _load_dismatch_prefix_params(net, parameter_dict, param_not_load, strict_load):
"""When some net parameter did not load, try to continue load."""
prefix_name = ""
longest_name = param_not_load[0]
@ -507,7 +521,7 @@ def _load_dismatch_prefix_params(net, parameter_dict, param_not_load):
new_param_name = prefix_name + param.name
if param.name in param_not_load and new_param_name in parameter_dict:
new_param = parameter_dict[new_param_name]
_update_param(param, new_param)
_update_param(param, new_param, strict_load)
param_not_load.remove(param.name)

View File

@ -236,7 +236,7 @@ def test_load_param_into_net_param_type_and_shape_error():
assert net.conv1.weight.data.asnumpy()[0][0][0][0] == 0
parameter_dict = {}
one_param = Parameter(Tensor(np.ones(shape=(64, 3, 7, 7))), name="conv1.weight")
one_param = Parameter(Tensor(np.ones(shape=(64, 3, 7, 7)), dtype=mstype.int32), name="conv1.weight")
parameter_dict["conv1.weight"] = one_param
with pytest.raises(RuntimeError):
load_param_into_net(net, parameter_dict)