forked from mindspore-Ecosystem/mindspore
!9465 optimize verb description
From: @caozhou_huawei Reviewed-by: @kingxian,@zhunaipan Signed-off-by: @kingxian
This commit is contained in:
commit
5d3ada2188
|
@ -108,7 +108,7 @@ def _update_param(param, new_param):
|
||||||
|
|
||||||
|
|
||||||
def _exec_save(ckpt_file_name, data_list):
|
def _exec_save(ckpt_file_name, data_list):
|
||||||
"""Execute save checkpoint into file process."""
|
"""Execute the process of saving checkpoint into file."""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with _ckpt_mutex:
|
with _ckpt_mutex:
|
||||||
|
@ -163,7 +163,7 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=F
|
||||||
integrated_save = Validator.check_bool(integrated_save)
|
integrated_save = Validator.check_bool(integrated_save)
|
||||||
async_save = Validator.check_bool(async_save)
|
async_save = Validator.check_bool(async_save)
|
||||||
|
|
||||||
logger.info("Execute save checkpoint process.")
|
logger.info("Execute the process of saving checkpoint.")
|
||||||
|
|
||||||
if isinstance(save_obj, nn.Cell):
|
if isinstance(save_obj, nn.Cell):
|
||||||
save_obj.init_parameters_data()
|
save_obj.init_parameters_data()
|
||||||
|
@ -209,7 +209,7 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=F
|
||||||
else:
|
else:
|
||||||
_exec_save(ckpt_file_name, data_list)
|
_exec_save(ckpt_file_name, data_list)
|
||||||
|
|
||||||
logger.info("Save checkpoint process finish.")
|
logger.info("Saving checkpoint process finished.")
|
||||||
|
|
||||||
|
|
||||||
def _check_param_prefix(filter_prefix, param_name):
|
def _check_param_prefix(filter_prefix, param_name):
|
||||||
|
@ -268,7 +268,7 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
|
||||||
raise TypeError(f"The type of filter_prefix must be str, list[str] or tuple[str], "
|
raise TypeError(f"The type of filter_prefix must be str, list[str] or tuple[str], "
|
||||||
f"but got {str(type(prefix))} at index {index}.")
|
f"but got {str(type(prefix))} at index {index}.")
|
||||||
|
|
||||||
logger.info("Execute load checkpoint process.")
|
logger.info("Execute the process of loading checkpoint.")
|
||||||
checkpoint_list = Checkpoint()
|
checkpoint_list = Checkpoint()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -312,7 +312,7 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
|
||||||
param_value = param_data.reshape(param_dim)
|
param_value = param_data.reshape(param_dim)
|
||||||
parameter_dict[element.tag] = Parameter(Tensor(param_value, ms_type), name=element.tag)
|
parameter_dict[element.tag] = Parameter(Tensor(param_value, ms_type), name=element.tag)
|
||||||
|
|
||||||
logger.info("Load checkpoint process finish.")
|
logger.info("Loading checkpoint process finished.")
|
||||||
|
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
logger.error("Failed to load the checkpoint file `%s`.", ckpt_file_name)
|
logger.error("Failed to load the checkpoint file `%s`.", ckpt_file_name)
|
||||||
|
@ -357,7 +357,7 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
|
||||||
raise TypeError(msg)
|
raise TypeError(msg)
|
||||||
|
|
||||||
strict_load = Validator.check_bool(strict_load)
|
strict_load = Validator.check_bool(strict_load)
|
||||||
logger.info("Execute load parameter into net process.")
|
logger.info("Execute the process of loading parameter into net.")
|
||||||
net.init_parameters_data()
|
net.init_parameters_data()
|
||||||
param_not_load = []
|
param_not_load = []
|
||||||
for _, param in net.parameters_and_names():
|
for _, param in net.parameters_and_names():
|
||||||
|
@ -378,7 +378,7 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
|
||||||
for param_name in param_not_load:
|
for param_name in param_not_load:
|
||||||
logger.debug("%s", param_name)
|
logger.debug("%s", param_name)
|
||||||
|
|
||||||
logger.info("Load parameter into net finish.")
|
logger.info("Loading parameter into net finished.")
|
||||||
if param_not_load:
|
if param_not_load:
|
||||||
logger.warning("{} parameters in the net are not loaded.".format(len(param_not_load)))
|
logger.warning("{} parameters in the net are not loaded.".format(len(param_not_load)))
|
||||||
return param_not_load
|
return param_not_load
|
||||||
|
@ -417,7 +417,7 @@ def _save_graph(network, file_name):
|
||||||
network (Cell): Obtain a pipeline through network for saving graph.
|
network (Cell): Obtain a pipeline through network for saving graph.
|
||||||
file_name (str): Graph file name into which the graph will be saved.
|
file_name (str): Graph file name into which the graph will be saved.
|
||||||
"""
|
"""
|
||||||
logger.info("Execute save the graph process.")
|
logger.info("Execute the process of saving graph.")
|
||||||
|
|
||||||
graph_proto = network.get_func_graph_proto()
|
graph_proto = network.get_func_graph_proto()
|
||||||
if graph_proto:
|
if graph_proto:
|
||||||
|
|
Loading…
Reference in New Issue