diff --git a/mindspore/parallel/_ps_context.py b/mindspore/parallel/_ps_context.py
index 5141eab2fb1..999c5adde5c 100644
--- a/mindspore/parallel/_ps_context.py
+++ b/mindspore/parallel/_ps_context.py
@@ -95,10 +95,11 @@ def _get_ps_context(attr_key):
     Raises:
         ValueError: If input key is not attribute in auto parallel context.
     """
-    if key not in _get_ps_context_func_map:
-        raise ValueError("Get PS context keyword %s is not recognized!" % key)
+    if attr_key not in _get_ps_context_func_map:
+        raise ValueError("Get PS context keyword %s is not recognized!" % attr_key)
     get_func = _get_ps_context_func_map[attr_key]
-    get_func(attr_key)
+    value = get_func()
+    return value
 
 def _reset_ps_context():
     """
diff --git a/mindspore/train/callback/_checkpoint.py b/mindspore/train/callback/_checkpoint.py
index 030fafbea82..91eba0b8ba9 100644
--- a/mindspore/train/callback/_checkpoint.py
+++ b/mindspore/train/callback/_checkpoint.py
@@ -228,6 +228,8 @@ class ModelCheckpoint(Callback):
         Args:
             run_context (RunContext): Context of the train running.
         """
+        if _is_role_pserver():
+            self._prefix = "PServer_" + str(_get_ps_mode_rank()) + "_" + self._prefix
         cb_params = run_context.original_args()
         # save graph (only once)
         if not self._graph_saved:
@@ -281,8 +283,6 @@ class ModelCheckpoint(Callback):
         if save_ckpt:
             cur_ckpoint_file = self._prefix + "-" + str(cb_params.cur_epoch_num) + "_" \
                 + str(step_num_in_epoch) + ".ckpt"
-            if _is_role_pserver():
-                cur_ckpoint_file = "PServer_" + str(_get_ps_mode_rank()) + "_" + cur_ckpoint_file
             # update checkpoint file list.
             self._manager.update_ckpoint_filelist(self._directory, self._prefix)
             # keep checkpoint files number equal max number.
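Reviewer note: the two hunks fix independent bugs. In `_get_ps_context`, the old code referenced the undefined name `key` (a `NameError` at runtime) and called the getter with an argument while discarding its result, so callers always received `None`. In `ModelCheckpoint`, the `PServer_<rank>_` prefix was previously applied only to the generated filename while `update_ckpoint_filelist(self._directory, self._prefix)` still received the unprefixed `self._prefix`; folding the prefix into `self._prefix` up front keeps the saved filename and the manager's prefix argument consistent.

A minimal, self-contained sketch of the getter-map pattern after the first fix; the map entries here are hypothetical stand-ins for the real zero-argument context getters:

```python
# Hypothetical stand-ins for the real PS context getters; the diff shows the
# actual map stores zero-argument functions keyed by keyword strings.
_get_ps_context_func_map = {
    "enable_ps": lambda: True,
    "ms_role": lambda: "MS_WORKER",
}

def _get_ps_context(attr_key):
    """Return the parameter-server context value registered for attr_key."""
    if attr_key not in _get_ps_context_func_map:
        raise ValueError("Get PS context keyword %s is not recognized!" % attr_key)
    get_func = _get_ps_context_func_map[attr_key]
    # The getters take no arguments; the value must be returned, otherwise
    # every caller of _get_ps_context would silently receive None.
    return get_func()

print(_get_ps_context("ms_role"))  # MS_WORKER
```

The second hunk's motivation can be shown the same way, assuming the checkpoint manager matches files by the prefix it is given (the rank and prefix values below are illustrative):

```python
# Before: only the filename carried the parameter-server prefix, so a
# prefix-based filter using self._prefix would miss the saved file.
prefix = "ckpt"                                   # self._prefix
filename = "PServer_1_" + prefix + "-3_200.ckpt"  # prefixed at save time only
print(filename.startswith(prefix))                # False

# After: the prefix is folded into self._prefix once, so they agree.
prefix = "PServer_1_" + prefix
filename = prefix + "-3_200.ckpt"
print(filename.startswith(prefix))                # True
```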