!9465 optimize verb description

From: @caozhou_huawei
Reviewed-by: @kingxian,@zhunaipan
Signed-off-by: @kingxian
This commit is contained in:
mindspore-ci-bot 2020-12-07 09:29:20 +08:00 committed by Gitee
commit 5d3ada2188
1 changed files with 8 additions and 8 deletions

View File

@ -108,7 +108,7 @@ def _update_param(param, new_param):
def _exec_save(ckpt_file_name, data_list): def _exec_save(ckpt_file_name, data_list):
"""Execute save checkpoint into file process.""" """Execute the process of saving checkpoint into file."""
try: try:
with _ckpt_mutex: with _ckpt_mutex:
@ -163,7 +163,7 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=F
integrated_save = Validator.check_bool(integrated_save) integrated_save = Validator.check_bool(integrated_save)
async_save = Validator.check_bool(async_save) async_save = Validator.check_bool(async_save)
logger.info("Execute save checkpoint process.") logger.info("Execute the process of saving checkpoint.")
if isinstance(save_obj, nn.Cell): if isinstance(save_obj, nn.Cell):
save_obj.init_parameters_data() save_obj.init_parameters_data()
@ -209,7 +209,7 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=F
else: else:
_exec_save(ckpt_file_name, data_list) _exec_save(ckpt_file_name, data_list)
logger.info("Save checkpoint process finish.") logger.info("Saving checkpoint process finished.")
def _check_param_prefix(filter_prefix, param_name): def _check_param_prefix(filter_prefix, param_name):
@ -268,7 +268,7 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
raise TypeError(f"The type of filter_prefix must be str, list[str] or tuple[str], " raise TypeError(f"The type of filter_prefix must be str, list[str] or tuple[str], "
f"but got {str(type(prefix))} at index {index}.") f"but got {str(type(prefix))} at index {index}.")
logger.info("Execute load checkpoint process.") logger.info("Execute the process of loading checkpoint.")
checkpoint_list = Checkpoint() checkpoint_list = Checkpoint()
try: try:
@ -312,7 +312,7 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
param_value = param_data.reshape(param_dim) param_value = param_data.reshape(param_dim)
parameter_dict[element.tag] = Parameter(Tensor(param_value, ms_type), name=element.tag) parameter_dict[element.tag] = Parameter(Tensor(param_value, ms_type), name=element.tag)
logger.info("Load checkpoint process finish.") logger.info("Loading checkpoint process finished.")
except BaseException as e: except BaseException as e:
logger.error("Failed to load the checkpoint file `%s`.", ckpt_file_name) logger.error("Failed to load the checkpoint file `%s`.", ckpt_file_name)
@ -357,7 +357,7 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
raise TypeError(msg) raise TypeError(msg)
strict_load = Validator.check_bool(strict_load) strict_load = Validator.check_bool(strict_load)
logger.info("Execute load parameter into net process.") logger.info("Execute the process of loading parameter into net.")
net.init_parameters_data() net.init_parameters_data()
param_not_load = [] param_not_load = []
for _, param in net.parameters_and_names(): for _, param in net.parameters_and_names():
@ -378,7 +378,7 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
for param_name in param_not_load: for param_name in param_not_load:
logger.debug("%s", param_name) logger.debug("%s", param_name)
logger.info("Load parameter into net finish.") logger.info("Loading parameter into net finished.")
if param_not_load: if param_not_load:
logger.warning("{} parameters in the net are not loaded.".format(len(param_not_load))) logger.warning("{} parameters in the net are not loaded.".format(len(param_not_load)))
return param_not_load return param_not_load
@ -417,7 +417,7 @@ def _save_graph(network, file_name):
network (Cell): Obtain a pipeline through network for saving graph. network (Cell): Obtain a pipeline through network for saving graph.
file_name (str): Graph file name into which the graph will be saved. file_name (str): Graph file name into which the graph will be saved.
""" """
logger.info("Execute save the graph process.") logger.info("Execute the process of saving graph.")
graph_proto = network.get_func_graph_proto() graph_proto = network.get_func_graph_proto()
if graph_proto: if graph_proto: