forked from mindspore-Ecosystem/mindspore
!6781 Change prefix for server ckpt callback
Merge pull request !6781 from ZPaC/master-change-prefix-for-server-ckpt
This commit is contained in:
commit
7f390467e9
|
@ -95,10 +95,11 @@ def _get_ps_context(attr_key):
|
|||
Raises:
|
||||
ValueError: If input key is not attribute in auto parallel context.
|
||||
"""
|
||||
if key not in _get_ps_context_func_map:
|
||||
raise ValueError("Get PS context keyword %s is not recognized!" % key)
|
||||
if attr_key not in _get_ps_context_func_map:
|
||||
raise ValueError("Get PS context keyword %s is not recognized!" % attr_key)
|
||||
get_func = _get_ps_context_func_map[attr_key]
|
||||
get_func(attr_key)
|
||||
value = get_func()
|
||||
return value
|
||||
|
||||
def _reset_ps_context():
|
||||
"""
|
||||
|
|
|
@ -228,6 +228,8 @@ class ModelCheckpoint(Callback):
|
|||
Args:
|
||||
run_context (RunContext): Context of the train running.
|
||||
"""
|
||||
if _is_role_pserver():
|
||||
self._prefix = "PServer_" + str(_get_ps_mode_rank()) + "_" + self._prefix
|
||||
cb_params = run_context.original_args()
|
||||
# save graph (only once)
|
||||
if not self._graph_saved:
|
||||
|
@ -281,8 +283,6 @@ class ModelCheckpoint(Callback):
|
|||
if save_ckpt:
|
||||
cur_ckpoint_file = self._prefix + "-" + str(cb_params.cur_epoch_num) + "_" \
|
||||
+ str(step_num_in_epoch) + ".ckpt"
|
||||
if _is_role_pserver():
|
||||
cur_ckpoint_file = "PServer_" + str(_get_ps_mode_rank()) + "_" + cur_ckpoint_file
|
||||
# update checkpoint file list.
|
||||
self._manager.update_ckpoint_filelist(self._directory, self._prefix)
|
||||
# keep checkpoint files number equal max number.
|
||||
|
|
Loading…
Reference in New Issue