make gpu dump truly async

This commit is contained in:
Parastoo Ashtari 2021-06-27 18:24:54 -04:00
parent 627b18b293
commit e2a4172560
7 changed files with 189 additions and 63 deletions

View File

@ -76,6 +76,21 @@ void E2eDump::DumpOutput(const session::KernelGraph *graph, const std::string &d
} }
} }
void E2eDump::DumpOutputSingleNode(const CNodePtr &node, const std::string &dump_path, const Debugger *debugger) {
auto &dump_json_parser = DumpJsonParser::GetInstance();
if (!dump_json_parser.OutputNeedDump()) {
return;
}
bool trans_flag = dump_json_parser.trans_flag();
MS_EXCEPTION_IF_NULL(node);
std::string kernel_name = node->fullname_with_scope();
if (!dump_json_parser.NeedDump(kernel_name)) {
return;
}
DumpJsonParser::GetInstance().MatchKernel(kernel_name);
DumpOutputImpl(node, trans_flag, dump_path, &kernel_name, debugger);
}
void E2eDump::DumpOutputImpl(const CNodePtr &node, bool trans_flag, const std::string &dump_path, void E2eDump::DumpOutputImpl(const CNodePtr &node, bool trans_flag, const std::string &dump_path,
std::string *kernel_name, const Debugger *debugger) { std::string *kernel_name, const Debugger *debugger) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
@ -128,6 +143,21 @@ void E2eDump::DumpInput(const session::KernelGraph *graph, const std::string &du
} }
} }
void E2eDump::DumpInputSingleNode(const CNodePtr &node, const std::string &dump_path, const Debugger *debugger) {
auto &dump_json_parser = DumpJsonParser::GetInstance();
if (!dump_json_parser.InputNeedDump()) {
return;
}
bool trans_flag = dump_json_parser.trans_flag();
MS_EXCEPTION_IF_NULL(node);
std::string kernel_name = node->fullname_with_scope();
if (!dump_json_parser.NeedDump(kernel_name)) {
return;
}
DumpJsonParser::GetInstance().MatchKernel(kernel_name);
DumpInputImpl(node, trans_flag, dump_path, &kernel_name, debugger);
}
void E2eDump::DumpInputImpl(const CNodePtr &node, bool trans_flag, const std::string &dump_path, void E2eDump::DumpInputImpl(const CNodePtr &node, bool trans_flag, const std::string &dump_path,
std::string *kernel_name, const Debugger *debugger) { std::string *kernel_name, const Debugger *debugger) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
@ -372,6 +402,32 @@ bool E2eDump::DumpData(const session::KernelGraph *graph, uint32_t rank_id, cons
return success; return success;
} }
bool E2eDump::DumpSingleNodeData(const CNodePtr &node, uint32_t graph_id, uint32_t rank_id, const Debugger *debugger) {
bool success = false;
auto &dump_json_parser = DumpJsonParser::GetInstance();
if (dump_json_parser.GetIterDumpFlag()) {
std::string dump_path = GenerateDumpPath(graph_id, rank_id);
DumpInputSingleNode(node, dump_path, debugger);
DumpOutputSingleNode(node, dump_path, debugger);
success = true;
}
return success;
}
bool E2eDump::DumpParametersAndConstData(const session::KernelGraph *graph, uint32_t rank_id,
const Debugger *debugger) {
bool success = false;
uint32_t graph_id = graph->graph_id();
auto &dump_json_parser = DumpJsonParser::GetInstance();
if (dump_json_parser.GetIterDumpFlag()) {
MS_LOG(INFO) << "DumpParametersAndConst. Current iteration is " << dump_json_parser.cur_dump_iter();
MS_LOG(INFO) << "Current graph id is " << graph_id;
std::string dump_path = GenerateDumpPath(graph_id, rank_id);
DumpParametersAndConst(graph, dump_path, debugger);
success = true;
}
return success;
}
bool E2eDump::isDatasetGraph(const session::KernelGraph *graph) { bool E2eDump::isDatasetGraph(const session::KernelGraph *graph) {
// check if there is GetNext or InitDataSetQueue node // check if there is GetNext or InitDataSetQueue node
const auto &nodes = graph->execution_order(); const auto &nodes = graph->execution_order();

View File

@ -36,6 +36,12 @@ class E2eDump {
~E2eDump() = default; ~E2eDump() = default;
static void DumpSetup(const session::KernelGraph *graph, uint32_t rank_id); static void DumpSetup(const session::KernelGraph *graph, uint32_t rank_id);
static bool DumpData(const session::KernelGraph *graph, uint32_t rank_id, const Debugger *debugger = nullptr); static bool DumpData(const session::KernelGraph *graph, uint32_t rank_id, const Debugger *debugger = nullptr);
static bool DumpParametersAndConstData(const session::KernelGraph *graph, uint32_t rank_id, const Debugger *debugger);
static bool DumpSingleNodeData(const CNodePtr &node, uint32_t graph_id, uint32_t rank_id,
const Debugger *debugger = nullptr);
static bool isDatasetGraph(const session::KernelGraph *graph); static bool isDatasetGraph(const session::KernelGraph *graph);
// Dump data when task error. // Dump data when task error.
static void DumpInputImpl(const CNodePtr &node, bool trans_flag, const std::string &dump_path, static void DumpInputImpl(const CNodePtr &node, bool trans_flag, const std::string &dump_path,
@ -45,8 +51,13 @@ class E2eDump {
private: private:
static void DumpOutput(const session::KernelGraph *graph, const std::string &dump_path, const Debugger *debugger); static void DumpOutput(const session::KernelGraph *graph, const std::string &dump_path, const Debugger *debugger);
static void DumpOutputSingleNode(const CNodePtr &node, const std::string &dump_path, const Debugger *debugger);
static void DumpInput(const session::KernelGraph *graph, const std::string &dump_path, const Debugger *debugger); static void DumpInput(const session::KernelGraph *graph, const std::string &dump_path, const Debugger *debugger);
static void DumpInputSingleNode(const CNodePtr &node, const std::string &dump_path, const Debugger *debugger);
static void DumpParametersAndConst(const session::KernelGraph *graph, const std::string &dump_path, static void DumpParametersAndConst(const session::KernelGraph *graph, const std::string &dump_path,
const Debugger *debugger); const Debugger *debugger);

View File

@ -317,7 +317,6 @@ void Debugger::PreExecuteGraphDebugger(const std::vector<KernelGraphPtr> &graphs
} }
void Debugger::PreExecute(const KernelGraphPtr &graph_ptr, uint32_t graph_sum) { void Debugger::PreExecute(const KernelGraphPtr &graph_ptr, uint32_t graph_sum) {
// access lock for public method // access lock for public method
std::lock_guard<std::mutex> a_lock(access_lock_); std::lock_guard<std::mutex> a_lock(access_lock_);
CheckDatasetSinkMode(); CheckDatasetSinkMode();
auto graph_id = graph_ptr->graph_id(); auto graph_id = graph_ptr->graph_id();
@ -392,7 +391,7 @@ bool Debugger::DumpDataEnabledIteration() const {
return false; return false;
} }
void Debugger::Dump(const KernelGraphPtr &kernel_graph) const { uint32_t Debugger::GetRankID() {
auto ms_context = MsContext::GetInstance(); auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context); MS_EXCEPTION_IF_NULL(ms_context);
std::string device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET); std::string device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
@ -400,23 +399,28 @@ void Debugger::Dump(const KernelGraphPtr &kernel_graph) const {
const auto &device_context = const auto &device_context =
device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_target, device_id}); device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_target, device_id});
uint32_t rank_id = device_context->GetRankID(); uint32_t rank_id = device_context->GetRankID();
return rank_id;
}
void Debugger::Dump(const KernelGraphPtr &kernel_graph) const {
uint32_t rank_id = GetRankID();
if (debugger_->DebuggerBackendEnabled()) { if (debugger_->DebuggerBackendEnabled()) {
MS_EXCEPTION_IF_NULL(kernel_graph); MS_EXCEPTION_IF_NULL(kernel_graph);
E2eDump::DumpData(kernel_graph.get(), rank_id, debugger_.get()); E2eDump::DumpParametersAndConstData(kernel_graph.get(), rank_id, debugger_.get());
} else { } else {
DumpJsonParser::GetInstance().UpdateDumpIter(); DumpJsonParser::GetInstance().UpdateDumpIter();
} }
} }
void Debugger::DumpSingleNode(const CNodePtr &node, uint32_t graph_id) {
if (debugger_->DebuggerBackendEnabled()) {
uint32_t rank_id = GetRankID();
E2eDump::DumpSingleNodeData(node, graph_id, rank_id, debugger_.get());
}
}
void Debugger::DumpSetup(const KernelGraphPtr &kernel_graph) const { void Debugger::DumpSetup(const KernelGraphPtr &kernel_graph) const {
MS_LOG(INFO) << "Start!"; MS_LOG(INFO) << "Start!";
auto ms_context = MsContext::GetInstance(); uint32_t rank_id = GetRankID();
MS_EXCEPTION_IF_NULL(ms_context);
std::string device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
uint32_t device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
const auto &device_context =
device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_target, device_id});
uint32_t rank_id = device_context->GetRankID();
MS_EXCEPTION_IF_NULL(kernel_graph); MS_EXCEPTION_IF_NULL(kernel_graph);
E2eDump::DumpSetup(kernel_graph.get(), rank_id); E2eDump::DumpSetup(kernel_graph.get(), rank_id);
MS_LOG(INFO) << "Finish!"; MS_LOG(INFO) << "Finish!";
@ -425,13 +429,7 @@ void Debugger::DumpInGraphCompiler(const KernelGraphPtr &kernel_graph) {
// This function will be called for new GPU runtime using MindRTBackend // This function will be called for new GPU runtime using MindRTBackend
auto &json_parser = DumpJsonParser::GetInstance(); auto &json_parser = DumpJsonParser::GetInstance();
if (json_parser.e2e_dump_enabled()) { if (json_parser.e2e_dump_enabled()) {
auto ms_context = MsContext::GetInstance(); uint32_t rank_id = GetRankID();
MS_EXCEPTION_IF_NULL(ms_context);
std::string device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
uint32_t device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
const auto &device_context =
device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_target, device_id});
uint32_t rank_id = device_context->GetRankID();
kernel_graph->set_root_graph_id(kernel_graph->graph_id()); kernel_graph->set_root_graph_id(kernel_graph->graph_id());
std::string final_graph = "trace_code_graph_" + std::to_string(kernel_graph->graph_id()); std::string final_graph = "trace_code_graph_" + std::to_string(kernel_graph->graph_id());
std::string root_dir = json_parser.path() + "/rank_" + std::to_string(rank_id); std::string root_dir = json_parser.path() + "/rank_" + std::to_string(rank_id);
@ -443,23 +441,32 @@ void Debugger::DumpInGraphCompiler(const KernelGraphPtr &kernel_graph) {
kernel_graph->execution_order()); kernel_graph->execution_order());
} }
} }
void Debugger::PostExecuteGraphDebugger(const std::vector<KernelGraphPtr> &graphs) {
void Debugger::PostExecuteGraphDebugger() {
// Only GPU is supported for MindRTBackend // Only GPU is supported for MindRTBackend
if (device_target_ != kGPUDevice) { if (device_target_ != kGPUDevice) {
return; return;
} }
for (size_t graph_index = 0; graph_index < graphs.size(); ++graph_index) { // LoadParametersAndConst for all the graphs
const auto &graph = graphs[graph_index]; for (auto graph : graph_ptr_list_) {
bool dump_enabled = debugger_->DumpDataEnabledIteration(); debugger_->LoadParametersAndConst(graph);
// debug used for dump }
if (debugger_ && dump_enabled) { bool dump_enabled = debugger_->DumpDataEnabledIteration();
// debug used for dump
if (debugger_ && dump_enabled) {
// Dump Parameters and consts
for (auto graph : graph_ptr_list_) {
debugger_->Dump(graph); debugger_->Dump(graph);
} else { if (!debugger_->debugger_enabled()) {
DumpJsonParser::GetInstance().UpdateDumpIter(); debugger_->ClearCurrentData();
} }
if (debugger_) {
debugger_->PostExecute();
} }
} else {
DumpJsonParser::GetInstance().UpdateDumpIter();
}
if (debugger_) {
debugger_->PostExecute();
} }
} }
@ -1340,6 +1347,24 @@ void Debugger::LoadParametersAndConst() {
} }
} }
void Debugger::LoadParametersAndConst(const KernelGraphPtr &graph) {
if (!(debugger_enabled_ || CheckDebuggerDumpEnabled())) return;
MS_EXCEPTION_IF_NULL(graph);
// load parameters
MS_LOG(INFO) << "Start to load Parameters for graph " << graph->graph_id();
const auto &parameters = graph_ptr_->inputs();
for (auto &item : parameters) {
LoadSingleAnfnode(item, PARAMETER_OUTPUT_INDEX);
}
// load value nodes
// get all constant avlues from the graph
MS_LOG(INFO) << "Start to load value nodes for graph " << graph->graph_id();
const auto value_nodes = graph_ptr_->graph_value_nodes();
for (auto &item : value_nodes) {
LoadSingleAnfnode(item, VALUE_NODE_OUTPUT_INDEX);
}
}
void Debugger::LoadGraphOutputs() { void Debugger::LoadGraphOutputs() {
if (!(debugger_enabled() && device_target_ == kAscendDevice)) return; if (!(debugger_enabled() && device_target_ == kAscendDevice)) return;
MS_EXCEPTION_IF_NULL(graph_ptr_); MS_EXCEPTION_IF_NULL(graph_ptr_);

View File

@ -85,13 +85,17 @@ class Debugger : public std::enable_shared_from_this<Debugger> {
bool DumpDataEnabledIteration() const; bool DumpDataEnabledIteration() const;
static uint32_t GetRankID();
void Dump(const KernelGraphPtr &kernel_graph) const; void Dump(const KernelGraphPtr &kernel_graph) const;
void DumpSingleNode(const CNodePtr &node, uint32_t graph_id);
void DumpSetup(const KernelGraphPtr &kernel_graph) const; void DumpSetup(const KernelGraphPtr &kernel_graph) const;
void DumpInGraphCompiler(const KernelGraphPtr &kernel_graph); void DumpInGraphCompiler(const KernelGraphPtr &kernel_graph);
void PostExecuteGraphDebugger(const std::vector<KernelGraphPtr> &graphs); void PostExecuteGraphDebugger();
bool ReadNodeDataRequired(const CNodePtr &kernel) const; bool ReadNodeDataRequired(const CNodePtr &kernel) const;
@ -141,6 +145,8 @@ class Debugger : public std::enable_shared_from_this<Debugger> {
void LoadParametersAndConst(); void LoadParametersAndConst();
void LoadParametersAndConst(const KernelGraphPtr &graph);
void UpdateStepNum(const session::KernelGraph *graph); void UpdateStepNum(const session::KernelGraph *graph);
void UpdateStepNumGPU(); void UpdateStepNumGPU();

View File

@ -27,6 +27,7 @@
using mindspore::kernel::AddressPtr; using mindspore::kernel::AddressPtr;
using AddressPtrList = std::vector<mindspore::kernel::AddressPtr>; using AddressPtrList = std::vector<mindspore::kernel::AddressPtr>;
using KernelGraph = mindspore::session::KernelGraph;
#endif #endif
namespace mindspore { namespace mindspore {
namespace runtime { namespace runtime {
@ -100,6 +101,56 @@ void LoadOutputs(const CNodePtr &cnode, const KernelLaunchInfo *launch_info_, ui
} }
} }
} }
bool CheckReadData(const CNodePtr &cnode) {
auto debugger = Debugger::GetInstance();
if (!debugger) {
return false;
}
bool read_data = false;
auto &dump_json_parser = DumpJsonParser::GetInstance();
bool dump_enabled = debugger->DumpDataEnabledIteration();
std::string kernel_name = cnode->fullname_with_scope();
if (dump_enabled) {
auto dump_mode = dump_json_parser.dump_mode();
// dump the node if dump_mode is 0, which means all kernels, or if this kernel is in the kernels list
if ((dump_mode == 0) || ((dump_mode == 1) && dump_json_parser.NeedDump(kernel_name))) {
read_data = true;
}
} else if (debugger->debugger_enabled()) {
read_data = debugger->ReadNodeDataRequired(cnode);
}
return read_data;
}
void ReadDataAndDump(const CNodePtr &cnode, const KernelLaunchInfo *launch_info_, uint32_t exec_order_) {
auto debugger = Debugger::GetInstance();
if (!debugger) {
return;
}
auto &dump_json_parser = DumpJsonParser::GetInstance();
bool dump_enabled = debugger->DumpDataEnabledIteration();
if (debugger->debugger_enabled() || dump_json_parser.InputNeedDump()) {
LoadInputs(cnode, launch_info_, exec_order_);
}
if (debugger->debugger_enabled() || dump_json_parser.OutputNeedDump()) {
LoadOutputs(cnode, launch_info_, exec_order_);
}
// Dump kernel
if (dump_enabled) {
auto kernel_graph = std::dynamic_pointer_cast<KernelGraph>(cnode->func_graph());
MS_EXCEPTION_IF_NULL(kernel_graph);
auto graph_id = kernel_graph->graph_id();
debugger->DumpSingleNode(cnode, graph_id);
// Clear Dumped data when online debugger is not enabled
if (!debugger->debugger_enabled()) {
debugger->ClearCurrentData();
}
}
// check if the node is last kernel
bool last_kernel = !AnfAlgo::IsInplaceNode(cnode, "skip");
debugger->PostExecuteNode(cnode, last_kernel);
}
#endif #endif
void DebugActor::Debug(const AnfNodePtr &node, const KernelLaunchInfo *launch_info_, void DebugActor::Debug(const AnfNodePtr &node, const KernelLaunchInfo *launch_info_,
@ -108,36 +159,19 @@ void DebugActor::Debug(const AnfNodePtr &node, const KernelLaunchInfo *launch_in
MS_EXCEPTION_IF_NULL(device_context); MS_EXCEPTION_IF_NULL(device_context);
MS_EXCEPTION_IF_NULL(op_context); MS_EXCEPTION_IF_NULL(op_context);
MS_EXCEPTION_IF_NULL(from_aid); MS_EXCEPTION_IF_NULL(from_aid);
// todo debug. // todo debug.
MS_LOG(INFO) << "DebugActor is called";
#ifdef ENABLE_GPU #ifdef ENABLE_GPU
if (node->isa<CNode>()) { if (node->isa<CNode>()) {
const auto &cnode = node->cast<CNodePtr>(); const auto &cnode = node->cast<CNodePtr>();
auto debugger = Debugger::GetInstance(); auto debugger = Debugger::GetInstance();
if (debugger) { if (debugger) {
std::string kernel_name = cnode->fullname_with_scope(); std::string kernel_name = cnode->fullname_with_scope();
MS_LOG(INFO) << "kernel_name is " << kernel_name;
debugger->SetCurNode(kernel_name); debugger->SetCurNode(kernel_name);
bool read_data = false; bool read_data = CheckReadData(cnode);
auto &dump_json_parser = DumpJsonParser::GetInstance();
bool dump_enabled = debugger->DumpDataEnabledIteration();
if (dump_enabled) {
auto dump_mode = dump_json_parser.dump_mode();
// dump the node if dump_mode is 0, which means all kernels, or if this kernel is in the kernels list
if ((dump_mode == 0) || ((dump_mode == 1) && dump_json_parser.NeedDump(kernel_name))) {
read_data = true;
}
} else if (debugger->debugger_enabled()) {
read_data = debugger->ReadNodeDataRequired(cnode);
}
if (read_data) { if (read_data) {
if (debugger->debugger_enabled() || dump_json_parser.InputNeedDump()) { ReadDataAndDump(cnode, launch_info_, exec_order_);
LoadInputs(cnode, launch_info_, exec_order_);
}
if (debugger->debugger_enabled() || dump_json_parser.OutputNeedDump()) {
LoadOutputs(cnode, launch_info_, exec_order_);
}
// check if the node is last kernel
bool last_kernel = !AnfAlgo::IsInplaceNode(cnode, "skip");
debugger->PostExecuteNode(cnode, last_kernel);
} }
} }
exec_order_ += 1; exec_order_ += 1;
@ -150,14 +184,15 @@ void DebugActor::Debug(const AnfNodePtr &node, const KernelLaunchInfo *launch_in
void DebugActor::DebugOnStepEnd(OpContext<DeviceTensor> *op_context, const AID *from_aid) { void DebugActor::DebugOnStepEnd(OpContext<DeviceTensor> *op_context, const AID *from_aid) {
MS_EXCEPTION_IF_NULL(op_context); MS_EXCEPTION_IF_NULL(op_context);
MS_EXCEPTION_IF_NULL(from_aid); MS_EXCEPTION_IF_NULL(from_aid);
// todo debug. // todo debug.
MS_LOG(INFO) << "DebugActor::DebugOnStepEnd is called";
#ifdef ENABLE_GPU #ifdef ENABLE_GPU
auto debugger = Debugger::GetInstance(); auto debugger = Debugger::GetInstance();
if (debugger) { if (debugger) {
debugger->Debugger::UpdateStepNumGPU(); debugger->Debugger::UpdateStepNumGPU();
debugger->Debugger::LoadParametersAndConst();
// Reset exec_order for the next step // Reset exec_order for the next step
exec_order_ = 0; exec_order_ = 0;
debugger->Debugger::PostExecuteGraphDebugger();
} }
#endif #endif
// Call back to the from actor to process after debug finished. // Call back to the from actor to process after debug finished.

View File

@ -304,16 +304,15 @@ GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph, const Devic
} }
graph->set_is_all_nop_node(opt::IsAllNopNode(graph.get())); graph->set_is_all_nop_node(opt::IsAllNopNode(graph.get()));
#ifdef ENABLE_DEBUGGER
auto debugger = Debugger::GetInstance();
debugger->DumpInGraphCompiler(graph);
#endif
MS_EXCEPTION_IF_NULL(session_); MS_EXCEPTION_IF_NULL(session_);
session_->InitAllBucket(graph, device_context); session_->InitAllBucket(graph, device_context);
session_->SetSummaryNodes(graph.get()); session_->SetSummaryNodes(graph.get());
SetSummaryNodesRefCount(graph.get()); SetSummaryNodesRefCount(graph.get());
#ifdef ENABLE_DEBUGGER #ifdef ENABLE_DEBUGGER
auto debugger = Debugger::GetInstance();
debugger->DumpInGraphCompiler(graph);
if (debugger && debugger->DebuggerBackendEnabled()) { if (debugger && debugger->DebuggerBackendEnabled()) {
debugger->LoadGraphs(graph); debugger->LoadGraphs(graph);
} }

View File

@ -606,12 +606,6 @@ void MindRTBackend::RunGraph(const ActorInfo &actor_info, const VectorRef &args,
if (!runtime::GraphScheduler::GetInstance().Run(actor_set)) { if (!runtime::GraphScheduler::GetInstance().Run(actor_set)) {
MS_LOG(EXCEPTION) << "The actor runs failed, actor name: " << actor_set->name_; MS_LOG(EXCEPTION) << "The actor runs failed, actor name: " << actor_set->name_;
} }
// Debugger post-execute graph.
#ifdef ENABLE_DEBUGGER
if (Debugger::GetInstance()->DebuggerBackendEnabled()) {
Debugger::GetInstance()->PostExecuteGraphDebugger(graph_compiler_info.graphs_);
}
#endif
// Sync device stream. // Sync device stream.
const auto &first_device_context = graph_compiler_info.device_contexts_[0]; const auto &first_device_context = graph_compiler_info.device_contexts_[0];