wait until dumping finished

This commit is contained in:
TinaMengtingZhang 2022-03-15 12:41:28 -04:00
parent 3056150514
commit 438e261bbc
4 changed files with 37 additions and 2 deletions

View File

@ -1783,6 +1783,27 @@ std::shared_ptr<DumpDataBuilder> Debugger::LoadDumpDataBuilder(const std::string
}
void Debugger::ClearDumpDataBuilder(const std::string &node_name) { (void)dump_data_construct_map_.erase(node_name); }
/*
* Feature group: Dump.
* Target device group: Ascend.
* Runtime category: MindRT.
* Description: This function is used for A+M dump to make sure training processing ends after tensor data have been
* dumped to disk completely. Check if dump_data_construct_map_ is empty to see if no dump task is alive. If not, sleep
* for 500ms and check again.
*/
void Debugger::WaitForWriteFileFinished() {
const int kRetryTimeInMilliseconds = 500;
const int kMaxRecheckCount = 10;
int recheck_cnt = 0;
while (recheck_cnt < kMaxRecheckCount && !dump_data_construct_map_.empty()) {
MS_LOG(INFO) << "Sleep for " << std::to_string(kRetryTimeInMilliseconds)
<< " ms to wait for dumping files to finish. Retry count: " << std::to_string(recheck_cnt + 1) << "/"
<< std::to_string(kMaxRecheckCount);
std::this_thread::sleep_for(std::chrono::milliseconds(kRetryTimeInMilliseconds));
recheck_cnt++;
}
}
#endif
} // namespace mindspore

View File

@ -199,6 +199,8 @@ class Debugger : public std::enable_shared_from_this<Debugger> {
std::shared_ptr<DumpDataBuilder> LoadDumpDataBuilder(const std::string &node_name);
void ClearDumpDataBuilder(const std::string &node_name);
void WaitForWriteFileFinished();
#endif
private:

View File

@ -249,6 +249,12 @@ bool AscendKernelRuntime::NeedDestroyHccl() {
#ifndef ENABLE_SECURITY
void AsyncDataDumpUninit() {
if (DumpJsonParser::GetInstance().async_dump_enabled()) {
#if ENABLE_D
// When it is A+M dump mode, wait until file save is finished.
if (DumpJsonParser::GetInstance().FileFormatIsNpy()) {
Debugger::GetInstance()->WaitForWriteFileFinished();
}
#endif
if (AdxDataDumpServerUnInit() != 0) {
MS_LOG(ERROR) << "Adx data dump server uninit failed";
}

View File

@ -433,6 +433,12 @@ def check_data_dump(dump_file_path):
expect = np.array([[8, 10, 12], [14, 16, 18]], np.float32)
assert np.array_equal(output, expect)
def run_train():
add = Net()
add(Tensor(x), Tensor(y))
def run_saved_data_dump_test(scenario, saved_data):
"""Run e2e dump on scenario, testing statistic dump"""
if sys.platform != 'linux':
@ -445,8 +451,8 @@ def run_saved_data_dump_test(scenario, saved_data):
dump_file_path = os.path.join(dump_path, 'rank_0', 'Net', '0', '0')
if os.path.isdir(dump_path):
shutil.rmtree(dump_path)
add = Net()
add(Tensor(x), Tensor(y))
exec_network_cmd = 'cd {0}; python -c "from test_data_dump import run_train; run_train()"'.format(os.getcwd())
_ = os.system(exec_network_cmd)
for _ in range(3):
if not os.path.exists(dump_file_path):
time.sleep(2)