From 438e261bbc8f00a1bbe43bb92460e2279ba9b685 Mon Sep 17 00:00:00 2001
From: TinaMengtingZhang
Date: Tue, 15 Mar 2022 12:41:28 -0400
Subject: [PATCH] wait until dumping finished

---
 mindspore/ccsrc/debug/debugger/debugger.cc    | 29 +++++++++++++++++++++++++++++
 mindspore/ccsrc/debug/debugger/debugger.h     |  2 ++
 .../hal/device/ascend_kernel_runtime.cc       |  6 ++++++
 tests/st/dump/test_data_dump.py               | 10 ++++++++--
 4 files changed, 45 insertions(+), 2 deletions(-)

diff --git a/mindspore/ccsrc/debug/debugger/debugger.cc b/mindspore/ccsrc/debug/debugger/debugger.cc
index 2fdb5f3a3d4..47921e8d7da 100644
--- a/mindspore/ccsrc/debug/debugger/debugger.cc
+++ b/mindspore/ccsrc/debug/debugger/debugger.cc
@@ -1783,6 +1783,35 @@ std::shared_ptr Debugger::LoadDumpDataBuilder(const std::string
 }
 
 void Debugger::ClearDumpDataBuilder(const std::string &node_name) { (void)dump_data_construct_map_.erase(node_name); }
+
+/*
+ * Feature group: Dump.
+ * Target device group: Ascend.
+ * Runtime category: MindRT.
+ * Description: This function is used for A+M dump to make sure training processing ends after tensor data have been
+ * dumped to disk completely. Check if dump_data_construct_map_ is empty to see if no dump task is alive. If not, sleep
+ * for 500ms and check again, up to 10 times (about 5 seconds in total). If dump tasks are still alive after the last
+ * retry, an error is logged and the function returns anyway so that shutdown is never blocked indefinitely.
+ */
+void Debugger::WaitForWriteFileFinished() {
+  constexpr int kRetryTimeInMilliseconds = 500;
+  constexpr int kMaxRecheckCount = 10;
+  int recheck_cnt = 0;
+  // NOTE(review): dump_data_construct_map_ is polled here without any visible locking; confirm that concurrent
+  // ClearDumpDataBuilder() calls are synchronized elsewhere.
+  while (recheck_cnt < kMaxRecheckCount && !dump_data_construct_map_.empty()) {
+    MS_LOG(INFO) << "Sleep for " << kRetryTimeInMilliseconds
+                 << " ms to wait for dumping files to finish. Retry count: " << (recheck_cnt + 1) << "/"
+                 << kMaxRecheckCount;
+    std::this_thread::sleep_for(std::chrono::milliseconds(kRetryTimeInMilliseconds));
+    recheck_cnt++;
+  }
+  // The retry budget is bounded, so report explicitly when it ran out with dump tasks still alive.
+  if (!dump_data_construct_map_.empty()) {
+    MS_LOG(ERROR) << "Dumping files did not finish within " << (kRetryTimeInMilliseconds * kMaxRecheckCount)
+                  << " ms, " << dump_data_construct_map_.size() << " dump task(s) are still alive.";
+  }
+}
 #endif
 
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/debug/debugger/debugger.h b/mindspore/ccsrc/debug/debugger/debugger.h
index 7c088f39bd2..d262da6d10f 100644
--- a/mindspore/ccsrc/debug/debugger/debugger.h
+++ b/mindspore/ccsrc/debug/debugger/debugger.h
@@ -199,6 +199,8 @@ class Debugger : public std::enable_shared_from_this {
   std::shared_ptr LoadDumpDataBuilder(const std::string &node_name);
 
   void ClearDumpDataBuilder(const std::string &node_name);
+
+  void WaitForWriteFileFinished();
 #endif
 
  private:
diff --git a/mindspore/ccsrc/plugin/device/ascend/hal/device/ascend_kernel_runtime.cc b/mindspore/ccsrc/plugin/device/ascend/hal/device/ascend_kernel_runtime.cc
index 7e2ea0ee0d1..4f253923b91 100644
--- a/mindspore/ccsrc/plugin/device/ascend/hal/device/ascend_kernel_runtime.cc
+++ b/mindspore/ccsrc/plugin/device/ascend/hal/device/ascend_kernel_runtime.cc
@@ -249,6 +249,12 @@ bool AscendKernelRuntime::NeedDestroyHccl() {
 #ifndef ENABLE_SECURITY
 void AsyncDataDumpUninit() {
   if (DumpJsonParser::GetInstance().async_dump_enabled()) {
+#if ENABLE_D
+    // When it is A+M dump mode, wait until file save is finished.
+    if (DumpJsonParser::GetInstance().FileFormatIsNpy()) {
+      Debugger::GetInstance()->WaitForWriteFileFinished();
+    }
+#endif
     if (AdxDataDumpServerUnInit() != 0) {
       MS_LOG(ERROR) << "Adx data dump server uninit failed";
     }
diff --git a/tests/st/dump/test_data_dump.py b/tests/st/dump/test_data_dump.py
index 1631187d3ef..34a0135b6d1 100644
--- a/tests/st/dump/test_data_dump.py
+++ b/tests/st/dump/test_data_dump.py
@@ -433,6 +433,12 @@ def check_data_dump(dump_file_path):
     expect = np.array([[8, 10, 12], [14, 16, 18]], np.float32)
     assert np.array_equal(output, expect)
 
+
+def run_train():
+    add = Net()
+    add(Tensor(x), Tensor(y))
+
+
 def run_saved_data_dump_test(scenario, saved_data):
     """Run e2e dump on scenario, testing statistic dump"""
     if sys.platform != 'linux':
@@ -445,8 +451,8 @@ def run_saved_data_dump_test(scenario, saved_data):
         dump_file_path = os.path.join(dump_path, 'rank_0', 'Net', '0', '0')
         if os.path.isdir(dump_path):
             shutil.rmtree(dump_path)
-        add = Net()
-        add(Tensor(x), Tensor(y))
+        exec_network_cmd = 'cd {0}; python -c "from test_data_dump import run_train; run_train()"'.format(os.getcwd())
+        _ = os.system(exec_network_cmd)
         for _ in range(3):
             if not os.path.exists(dump_file_path):
                 time.sleep(2)