!6781 Change prefix for server ckpt callback

Merge pull request !6781 from ZPaC/master-change-prefix-for-server-ckpt
This commit is contained in:
mindspore-ci-bot 2020-09-24 10:28:18 +08:00 committed by Gitee
commit 7f390467e9
2 changed files with 6 additions and 5 deletions

View File

@@ -95,10 +95,11 @@ def _get_ps_context(attr_key):
Raises:
ValueError: If input key is not attribute in auto parallel context.
"""
if key not in _get_ps_context_func_map:
raise ValueError("Get PS context keyword %s is not recognized!" % key)
if attr_key not in _get_ps_context_func_map:
raise ValueError("Get PS context keyword %s is not recognized!" % attr_key)
get_func = _get_ps_context_func_map[attr_key]
get_func(attr_key)
value = get_func()
return value
def _reset_ps_context():
"""

View File

@@ -228,6 +228,8 @@ class ModelCheckpoint(Callback):
Args:
run_context (RunContext): Context of the train running.
"""
if _is_role_pserver():
self._prefix = "PServer_" + str(_get_ps_mode_rank()) + "_" + self._prefix
cb_params = run_context.original_args()
# save graph (only once)
if not self._graph_saved:
@@ -281,8 +283,6 @@ class ModelCheckpoint(Callback):
if save_ckpt:
cur_ckpoint_file = self._prefix + "-" + str(cb_params.cur_epoch_num) + "_" \
+ str(step_num_in_epoch) + ".ckpt"
if _is_role_pserver():
cur_ckpoint_file = "PServer_" + str(_get_ps_mode_rank()) + "_" + cur_ckpoint_file
# update checkpoint file list.
self._manager.update_ckpoint_filelist(self._directory, self._prefix)
# keep checkpoint files number equal max number.