modify asyn ckpt

This commit is contained in:
changzherui 2021-02-03 23:07:56 +08:00
parent 62d272f648
commit 3b87921ad0
1 changed files with 7 additions and 5 deletions

View File

@ -280,7 +280,10 @@ class ModelCheckpoint(Callback):
os.remove(graph_file_name)
_save_graph(cb_params.train_network, graph_file_name)
self._graph_saved = True
thread_list = threading.enumerate()
for thread in thread_list:
if thread.getName() == "asyn_save_ckpt":
thread.join()
self._save_ckpt(cb_params)
def end(self, run_context):
@ -295,10 +298,9 @@ class ModelCheckpoint(Callback):
self._save_ckpt(cb_params, _to_save_last_ckpt)
thread_list = threading.enumerate()
if len(thread_list) > 1:
for thread in thread_list:
if thread.getName() == "asyn_save_ckpt":
thread.join()
for thread in thread_list:
if thread.getName() == "asyn_save_ckpt":
thread.join()
from mindspore.parallel._cell_wrapper import destroy_allgather_cell
destroy_allgather_cell()