From 57ff5b20776c12da8f718332d243c8d149af269a Mon Sep 17 00:00:00 2001 From: TinaMengtingZhang Date: Thu, 17 Feb 2022 11:16:25 -0500 Subject: [PATCH] fix input parameter dump error --- mindspore/ccsrc/debug/data_dump/e2e_dump.cc | 31 ++++++++++++++++--- mindspore/ccsrc/debug/data_dump/e2e_dump.h | 9 ++++-- mindspore/ccsrc/debug/debugger/debugger.cc | 4 +-- mindspore/ccsrc/debug/debugger/debugger.h | 3 +- .../ccsrc/debug/debugger/debugger_utils.cc | 2 +- 5 files changed, 37 insertions(+), 12 deletions(-) diff --git a/mindspore/ccsrc/debug/data_dump/e2e_dump.cc b/mindspore/ccsrc/debug/data_dump/e2e_dump.cc index b11d4ebc0a6..d968dd8286b 100644 --- a/mindspore/ccsrc/debug/data_dump/e2e_dump.cc +++ b/mindspore/ccsrc/debug/data_dump/e2e_dump.cc @@ -35,6 +35,7 @@ #include "utils/file_utils.h" #include "debug/data_dump/tensor_stat_dump.h" #include "abstract/utils.h" +#include "runtime/hardware/device_context_manager.h" #ifdef ENABLE_DEBUGGER #include "debug/debug_services.h" #include "debug/tensor_load.h" @@ -212,7 +213,8 @@ void E2eDump::DumpInput(const session::KernelGraph *graph, const std::string &du } } -void E2eDump::DumpInputSingleNode(const CNodePtr &node, const std::string &dump_path, const Debugger *debugger) { +void E2eDump::DumpInputSingleNode(const CNodePtr &node, const std::string &dump_path, const Debugger *debugger, + const KernelLaunchInfo *launch_info) { auto &dump_json_parser = DumpJsonParser::GetInstance(); if (!dump_json_parser.InputNeedDump()) { return; @@ -224,11 +226,25 @@ void E2eDump::DumpInputSingleNode(const CNodePtr &node, const std::string &dump_ return; } DumpJsonParser::GetInstance().MatchKernel(kernel_name); - DumpInputImpl(node, trans_flag, dump_path, &kernel_name, debugger); + DumpInputImpl(node, trans_flag, dump_path, &kernel_name, debugger, launch_info); +} + +std::shared_ptr CreateAscendDeviceAddress(const KernelLaunchInfo *launch_info, size_t index, + TypeId type) { + MS_EXCEPTION_IF_NULL(launch_info); + auto addr_ptr = launch_info->inputs_[index]; + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + auto device_id = ms_context->get_param(MS_CTX_DEVICE_ID); + auto device_context = + device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({kAscendDevice, device_id}); + auto format = kOpFormat_DEFAULT; + MS_EXCEPTION_IF_NULL(addr_ptr); + return device_context->CreateDeviceAddress(addr_ptr->addr, addr_ptr->size, format, type); } void E2eDump::DumpInputImpl(const CNodePtr &node, bool trans_flag, const std::string &dump_path, - std::string *kernel_name, const Debugger *debugger) { + std::string *kernel_name, const Debugger *debugger, const KernelLaunchInfo *launch_info) { MS_EXCEPTION_IF_NULL(node); GetFileKernelName(NOT_NULL(kernel_name)); auto input_size = AnfAlgo::GetInputTensorNum(node); @@ -270,6 +286,10 @@ void E2eDump::DumpInputImpl(const CNodePtr &node, bool trans_flag, const std::st if (DumpJsonParser::GetInstance().IsTensorDump()) { if (IsDeviceTargetGPU()) { DumpGPUMemToFile(file_path, tensor_name, *addr, int_shapes, type, device_type, trans_flag, slot, debugger); + } else if (Debugger::GetInstance()->GetAscendKernelByKernelFlag()) { + // load address from launch_info when it's Ascend Kernel by kernel mode. + auto ascend_device_addr = CreateAscendDeviceAddress(launch_info, j, type); + DumpMemToFile(file_path, *ascend_device_addr, int_shapes, type, trans_flag); } else { DumpMemToFile(file_path, *addr, int_shapes, type, trans_flag); } @@ -529,12 +549,13 @@ void E2eDump::DumpData(const session::KernelGraph *graph, uint32_t rank_id, cons } } -bool E2eDump::DumpSingleNodeData(const CNodePtr &node, uint32_t graph_id, uint32_t rank_id, const Debugger *debugger) { +bool E2eDump::DumpSingleNodeData(const CNodePtr &node, uint32_t graph_id, uint32_t rank_id, const Debugger *debugger, + const KernelLaunchInfo *launch_info) { bool success = false; auto &dump_json_parser = DumpJsonParser::GetInstance(); if (dump_json_parser.DumpEnabledForIter()) { std::string dump_path = GenerateDumpPath(graph_id, rank_id); - DumpInputSingleNode(node, dump_path, debugger); + DumpInputSingleNode(node, dump_path, debugger, launch_info); DumpOutputSingleNode(node, dump_path, debugger); success = true; } diff --git a/mindspore/ccsrc/debug/data_dump/e2e_dump.h b/mindspore/ccsrc/debug/data_dump/e2e_dump.h index d4b71d9c30c..f59b5704ead 100644 --- a/mindspore/ccsrc/debug/data_dump/e2e_dump.h +++ b/mindspore/ccsrc/debug/data_dump/e2e_dump.h @@ -29,6 +29,7 @@ #include "proto/dump_data.pb.h" #endif +using mindspore::kernel::KernelLaunchInfo; #ifndef ENABLE_DEBUGGER class Debugger; #endif @@ -53,11 +54,12 @@ class E2eDump { static void DumpParametersData(uint32_t rank_id, const Debugger *debugger); static bool DumpSingleNodeData(const CNodePtr &node, uint32_t graph_id, uint32_t rank_id, - const Debugger *debugger = nullptr); + const Debugger *debugger = nullptr, const KernelLaunchInfo *launch_info = nullptr); // Dump data when task error. static void DumpInputImpl(const CNodePtr &node, bool trans_flag, const std::string &dump_path, - std::string *kernel_name, const Debugger *debugger); + std::string *kernel_name, const Debugger *debugger, + const KernelLaunchInfo *launch_info = nullptr); static void DumpOutputImpl(const CNodePtr &node, bool trans_flag, const std::string &dump_path, std::string *kernel_name, const Debugger *debugger); @@ -78,7 +80,8 @@ class E2eDump { static void DumpInput(const session::KernelGraph *graph, const std::string &dump_path, const Debugger *debugger); - static void DumpInputSingleNode(const CNodePtr &node, const std::string &dump_path, const Debugger *debugger); + static void DumpInputSingleNode(const CNodePtr &node, const std::string &dump_path, const Debugger *debugger, + const KernelLaunchInfo *launch_info = nullptr); static void DumpParameters(const session::KernelGraph *graph, const std::string &dump_path, const Debugger *debugger); diff --git a/mindspore/ccsrc/debug/debugger/debugger.cc b/mindspore/ccsrc/debug/debugger/debugger.cc index 2a6d5af7e70..effa2660a00 100644 --- a/mindspore/ccsrc/debug/debugger/debugger.cc +++ b/mindspore/ccsrc/debug/debugger/debugger.cc @@ -462,10 +462,10 @@ void Debugger::DumpConstantDataAscend(const KernelGraphPtr &graph) { } } -void Debugger::DumpSingleNode(const CNodePtr &node, uint32_t graph_id) { +void Debugger::DumpSingleNode(const CNodePtr &node, uint32_t graph_id, const KernelLaunchInfo *launch_info) { if (debugger_ && debugger_->DebuggerBackendEnabled()) { uint32_t rank_id = GetRankID(); - (void)E2eDump::DumpSingleNodeData(node, graph_id, rank_id, debugger_.get()); + (void)E2eDump::DumpSingleNodeData(node, graph_id, rank_id, debugger_.get(), launch_info); } } diff --git a/mindspore/ccsrc/debug/debugger/debugger.h b/mindspore/ccsrc/debug/debugger/debugger.h index 51c631158a0..4d9ae0a29cd 100644 --- a/mindspore/ccsrc/debug/debugger/debugger.h +++ b/mindspore/ccsrc/debug/debugger/debugger.h @@ -44,6 +44,7 @@ using debugger::WatchNode; using debugger::WatchpointHit; using DeviceTensor = mindspore::device::DeviceAddress; using DeviceTensorPtr = std::shared_ptr; +using mindspore::kernel::KernelLaunchInfo; template using ProtoVector = google::protobuf::RepeatedPtrField; @@ -105,7 +106,7 @@ class Debugger : public std::enable_shared_from_this { void DumpConstantDataAscend(const KernelGraphPtr &graph); - void DumpSingleNode(const CNodePtr &node, uint32_t graph_id); + void DumpSingleNode(const CNodePtr &node, uint32_t graph_id, const KernelLaunchInfo *launch_info = nullptr); void DumpInGraphCompiler(const KernelGraphPtr &kernel_graph); diff --git a/mindspore/ccsrc/debug/debugger/debugger_utils.cc b/mindspore/ccsrc/debug/debugger/debugger_utils.cc index 77069582ca8..e27d97528f0 100644 --- a/mindspore/ccsrc/debug/debugger/debugger_utils.cc +++ b/mindspore/ccsrc/debug/debugger/debugger_utils.cc @@ -168,7 +168,7 @@ void ReadDataAndDump(const CNodePtr &cnode, const KernelLaunchInfo *launch_info, debugger->DumpSingleNode(cnode, graph_id); } else { // for Ascend, node are dumped in root_graph_id directory. - debugger->DumpSingleNode(cnode, root_graph_id); + debugger->DumpSingleNode(cnode, root_graph_id, launch_info); } // Clear Dumped data when online debugger is not enabled if (!debugger->debugger_enabled()) {