forked from mindspore-Ecosystem/mindspore
!3252 modify asyn save checkpoint bug
Merge pull request !3252 from changzherui/master
This commit is contained in:
commit
156c42ef85
|
@ -18,6 +18,7 @@ import os
|
|||
import stat
|
||||
import time
|
||||
|
||||
import threading
|
||||
import mindspore.context as context
|
||||
from mindspore import log as logger
|
||||
from mindspore._checkparam import check_bool, check_int_non_negative
|
||||
|
@ -245,6 +246,12 @@ class ModelCheckpoint(Callback):
|
|||
_to_save_last_ckpt = True
|
||||
self._save_ckpt(cb_params, _to_save_last_ckpt)
|
||||
|
||||
thread_list = threading.enumerate()
|
||||
if len(thread_list) > 1:
|
||||
for thread in thread_list:
|
||||
if thread.getName() == "asyn_save_ckpt":
|
||||
thread.join()
|
||||
|
||||
from mindspore.parallel._cell_wrapper import destroy_allgather_cell
|
||||
destroy_allgather_cell()
|
||||
|
||||
|
|
|
@ -160,7 +160,7 @@ def save_checkpoint(parameter_list, ckpt_file_name, async_save=False):
|
|||
data_list[key].append(data)
|
||||
|
||||
if async_save:
|
||||
thr = Thread(target=_exec_save, args=(ckpt_file_name, data_list))
|
||||
thr = Thread(target=_exec_save, args=(ckpt_file_name, data_list), name="asyn_save_ckpt")
|
||||
thr.start()
|
||||
else:
|
||||
_exec_save(ckpt_file_name, data_list)
|
||||
|
|
Loading…
Reference in New Issue