From 99a2ab4b2e20850e7984d302ac7cff4b57801d68 Mon Sep 17 00:00:00 2001
From: changzherui
Date: Mon, 20 Jul 2020 22:33:42 +0800
Subject: [PATCH] modify asyn save checkpoint bug

---
 mindspore/train/callback/_checkpoint.py | 7 +++++++
 mindspore/train/serialization.py        | 2 +-
 2 files changed, 8 insertions(+), 1 deletion(-)

diff --git a/mindspore/train/callback/_checkpoint.py b/mindspore/train/callback/_checkpoint.py
index a9389fd395e..152e77704eb 100644
--- a/mindspore/train/callback/_checkpoint.py
+++ b/mindspore/train/callback/_checkpoint.py
@@ -18,6 +18,7 @@
 import os
 import stat
 import time
+import threading
 import mindspore.context as context
 from mindspore import log as logger
 from mindspore._checkparam import check_bool, check_int_non_negative
@@ -245,6 +246,12 @@ class ModelCheckpoint(Callback):
         _to_save_last_ckpt = True
         self._save_ckpt(cb_params, _to_save_last_ckpt)
 
+        thread_list = threading.enumerate()
+        if len(thread_list) > 1:
+            for thread in thread_list:
+                if thread.getName() == "asyn_save_ckpt":
+                    thread.join()
+
         from mindspore.parallel._cell_wrapper import destroy_allgather_cell
         destroy_allgather_cell()
 
diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py
index de35981d314..4277797731d 100644
--- a/mindspore/train/serialization.py
+++ b/mindspore/train/serialization.py
@@ -160,7 +160,7 @@ def save_checkpoint(parameter_list, ckpt_file_name, async_save=False):
             data_list[key].append(data)
 
     if async_save:
-        thr = Thread(target=_exec_save, args=(ckpt_file_name, data_list))
+        thr = Thread(target=_exec_save, args=(ckpt_file_name, data_list), name="asyn_save_ckpt")
         thr.start()
     else:
         _exec_save(ckpt_file_name, data_list)