make gpu dump truly async

This commit is contained in:
Parastoo Ashtari 2021-06-27 18:24:54 -04:00
parent 627b18b293
commit e2a4172560
7 changed files with 189 additions and 63 deletions

View File

@ -76,6 +76,21 @@ void E2eDump::DumpOutput(const session::KernelGraph *graph, const std::string &d
}
}
void E2eDump::DumpOutputSingleNode(const CNodePtr &node, const std::string &dump_path, const Debugger *debugger) {
auto &dump_json_parser = DumpJsonParser::GetInstance();
if (!dump_json_parser.OutputNeedDump()) {
return;
}
bool trans_flag = dump_json_parser.trans_flag();
MS_EXCEPTION_IF_NULL(node);
std::string kernel_name = node->fullname_with_scope();
if (!dump_json_parser.NeedDump(kernel_name)) {
return;
}
DumpJsonParser::GetInstance().MatchKernel(kernel_name);
DumpOutputImpl(node, trans_flag, dump_path, &kernel_name, debugger);
}
void E2eDump::DumpOutputImpl(const CNodePtr &node, bool trans_flag, const std::string &dump_path,
std::string *kernel_name, const Debugger *debugger) {
MS_EXCEPTION_IF_NULL(node);
@ -128,6 +143,21 @@ void E2eDump::DumpInput(const session::KernelGraph *graph, const std::string &du
}
}
void E2eDump::DumpInputSingleNode(const CNodePtr &node, const std::string &dump_path, const Debugger *debugger) {
auto &dump_json_parser = DumpJsonParser::GetInstance();
if (!dump_json_parser.InputNeedDump()) {
return;
}
bool trans_flag = dump_json_parser.trans_flag();
MS_EXCEPTION_IF_NULL(node);
std::string kernel_name = node->fullname_with_scope();
if (!dump_json_parser.NeedDump(kernel_name)) {
return;
}
DumpJsonParser::GetInstance().MatchKernel(kernel_name);
DumpInputImpl(node, trans_flag, dump_path, &kernel_name, debugger);
}
void E2eDump::DumpInputImpl(const CNodePtr &node, bool trans_flag, const std::string &dump_path,
std::string *kernel_name, const Debugger *debugger) {
MS_EXCEPTION_IF_NULL(node);
@ -372,6 +402,32 @@ bool E2eDump::DumpData(const session::KernelGraph *graph, uint32_t rank_id, cons
return success;
}
bool E2eDump::DumpSingleNodeData(const CNodePtr &node, uint32_t graph_id, uint32_t rank_id, const Debugger *debugger) {
bool success = false;
auto &dump_json_parser = DumpJsonParser::GetInstance();
if (dump_json_parser.GetIterDumpFlag()) {
std::string dump_path = GenerateDumpPath(graph_id, rank_id);
DumpInputSingleNode(node, dump_path, debugger);
DumpOutputSingleNode(node, dump_path, debugger);
success = true;
}
return success;
}
bool E2eDump::DumpParametersAndConstData(const session::KernelGraph *graph, uint32_t rank_id,
const Debugger *debugger) {
bool success = false;
uint32_t graph_id = graph->graph_id();
auto &dump_json_parser = DumpJsonParser::GetInstance();
if (dump_json_parser.GetIterDumpFlag()) {
MS_LOG(INFO) << "DumpParametersAndConst. Current iteration is " << dump_json_parser.cur_dump_iter();
MS_LOG(INFO) << "Current graph id is " << graph_id;
std::string dump_path = GenerateDumpPath(graph_id, rank_id);
DumpParametersAndConst(graph, dump_path, debugger);
success = true;
}
return success;
}
bool E2eDump::isDatasetGraph(const session::KernelGraph *graph) {
// check if there is GetNext or InitDataSetQueue node
const auto &nodes = graph->execution_order();

View File

@ -36,6 +36,12 @@ class E2eDump {
~E2eDump() = default;
static void DumpSetup(const session::KernelGraph *graph, uint32_t rank_id);
static bool DumpData(const session::KernelGraph *graph, uint32_t rank_id, const Debugger *debugger = nullptr);
static bool DumpParametersAndConstData(const session::KernelGraph *graph, uint32_t rank_id, const Debugger *debugger);
static bool DumpSingleNodeData(const CNodePtr &node, uint32_t graph_id, uint32_t rank_id,
const Debugger *debugger = nullptr);
static bool isDatasetGraph(const session::KernelGraph *graph);
// Dump data when task error.
static void DumpInputImpl(const CNodePtr &node, bool trans_flag, const std::string &dump_path,
@ -45,8 +51,13 @@ class E2eDump {
private:
static void DumpOutput(const session::KernelGraph *graph, const std::string &dump_path, const Debugger *debugger);
static void DumpOutputSingleNode(const CNodePtr &node, const std::string &dump_path, const Debugger *debugger);
static void DumpInput(const session::KernelGraph *graph, const std::string &dump_path, const Debugger *debugger);
static void DumpInputSingleNode(const CNodePtr &node, const std::string &dump_path, const Debugger *debugger);
static void DumpParametersAndConst(const session::KernelGraph *graph, const std::string &dump_path,
const Debugger *debugger);

View File

@ -317,7 +317,6 @@ void Debugger::PreExecuteGraphDebugger(const std::vector<KernelGraphPtr> &graphs
}
void Debugger::PreExecute(const KernelGraphPtr &graph_ptr, uint32_t graph_sum) {
// access lock for public method
std::lock_guard<std::mutex> a_lock(access_lock_);
CheckDatasetSinkMode();
auto graph_id = graph_ptr->graph_id();
@ -392,7 +391,7 @@ bool Debugger::DumpDataEnabledIteration() const {
return false;
}
void Debugger::Dump(const KernelGraphPtr &kernel_graph) const {
uint32_t Debugger::GetRankID() {
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
std::string device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
@ -400,23 +399,28 @@ void Debugger::Dump(const KernelGraphPtr &kernel_graph) const {
const auto &device_context =
device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_target, device_id});
uint32_t rank_id = device_context->GetRankID();
return rank_id;
}
void Debugger::Dump(const KernelGraphPtr &kernel_graph) const {
uint32_t rank_id = GetRankID();
if (debugger_->DebuggerBackendEnabled()) {
MS_EXCEPTION_IF_NULL(kernel_graph);
E2eDump::DumpData(kernel_graph.get(), rank_id, debugger_.get());
E2eDump::DumpParametersAndConstData(kernel_graph.get(), rank_id, debugger_.get());
} else {
DumpJsonParser::GetInstance().UpdateDumpIter();
}
}
void Debugger::DumpSingleNode(const CNodePtr &node, uint32_t graph_id) {
if (debugger_->DebuggerBackendEnabled()) {
uint32_t rank_id = GetRankID();
E2eDump::DumpSingleNodeData(node, graph_id, rank_id, debugger_.get());
}
}
void Debugger::DumpSetup(const KernelGraphPtr &kernel_graph) const {
MS_LOG(INFO) << "Start!";
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
std::string device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
uint32_t device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
const auto &device_context =
device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_target, device_id});
uint32_t rank_id = device_context->GetRankID();
uint32_t rank_id = GetRankID();
MS_EXCEPTION_IF_NULL(kernel_graph);
E2eDump::DumpSetup(kernel_graph.get(), rank_id);
MS_LOG(INFO) << "Finish!";
@ -425,13 +429,7 @@ void Debugger::DumpInGraphCompiler(const KernelGraphPtr &kernel_graph) {
// This function will be called for new GPU runtime using MindRTBackend
auto &json_parser = DumpJsonParser::GetInstance();
if (json_parser.e2e_dump_enabled()) {
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
std::string device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
uint32_t device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
const auto &device_context =
device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_target, device_id});
uint32_t rank_id = device_context->GetRankID();
uint32_t rank_id = GetRankID();
kernel_graph->set_root_graph_id(kernel_graph->graph_id());
std::string final_graph = "trace_code_graph_" + std::to_string(kernel_graph->graph_id());
std::string root_dir = json_parser.path() + "/rank_" + std::to_string(rank_id);
@ -443,17 +441,27 @@ void Debugger::DumpInGraphCompiler(const KernelGraphPtr &kernel_graph) {
kernel_graph->execution_order());
}
}
void Debugger::PostExecuteGraphDebugger(const std::vector<KernelGraphPtr> &graphs) {
void Debugger::PostExecuteGraphDebugger() {
// Only GPU is supported for MindRTBackend
if (device_target_ != kGPUDevice) {
return;
}
for (size_t graph_index = 0; graph_index < graphs.size(); ++graph_index) {
const auto &graph = graphs[graph_index];
// LoadParametersAndConst for all the graphs
for (auto graph : graph_ptr_list_) {
debugger_->LoadParametersAndConst(graph);
}
bool dump_enabled = debugger_->DumpDataEnabledIteration();
// debug used for dump
if (debugger_ && dump_enabled) {
// Dump Parameters and consts
for (auto graph : graph_ptr_list_) {
debugger_->Dump(graph);
if (!debugger_->debugger_enabled()) {
debugger_->ClearCurrentData();
}
}
} else {
DumpJsonParser::GetInstance().UpdateDumpIter();
}
@ -461,7 +469,6 @@ void Debugger::PostExecuteGraphDebugger(const std::vector<KernelGraphPtr> &graph
debugger_->PostExecute();
}
}
}
void Debugger::PostExecute() {
// access lock for public method
@ -1340,6 +1347,24 @@ void Debugger::LoadParametersAndConst() {
}
}
void Debugger::LoadParametersAndConst(const KernelGraphPtr &graph) {
if (!(debugger_enabled_ || CheckDebuggerDumpEnabled())) return;
MS_EXCEPTION_IF_NULL(graph);
// load parameters
MS_LOG(INFO) << "Start to load Parameters for graph " << graph->graph_id();
const auto &parameters = graph_ptr_->inputs();
for (auto &item : parameters) {
LoadSingleAnfnode(item, PARAMETER_OUTPUT_INDEX);
}
// load value nodes
// get all constant avlues from the graph
MS_LOG(INFO) << "Start to load value nodes for graph " << graph->graph_id();
const auto value_nodes = graph_ptr_->graph_value_nodes();
for (auto &item : value_nodes) {
LoadSingleAnfnode(item, VALUE_NODE_OUTPUT_INDEX);
}
}
void Debugger::LoadGraphOutputs() {
if (!(debugger_enabled() && device_target_ == kAscendDevice)) return;
MS_EXCEPTION_IF_NULL(graph_ptr_);

View File

@ -85,13 +85,17 @@ class Debugger : public std::enable_shared_from_this<Debugger> {
bool DumpDataEnabledIteration() const;
static uint32_t GetRankID();
void Dump(const KernelGraphPtr &kernel_graph) const;
void DumpSingleNode(const CNodePtr &node, uint32_t graph_id);
void DumpSetup(const KernelGraphPtr &kernel_graph) const;
void DumpInGraphCompiler(const KernelGraphPtr &kernel_graph);
void PostExecuteGraphDebugger(const std::vector<KernelGraphPtr> &graphs);
void PostExecuteGraphDebugger();
bool ReadNodeDataRequired(const CNodePtr &kernel) const;
@ -141,6 +145,8 @@ class Debugger : public std::enable_shared_from_this<Debugger> {
void LoadParametersAndConst();
void LoadParametersAndConst(const KernelGraphPtr &graph);
void UpdateStepNum(const session::KernelGraph *graph);
void UpdateStepNumGPU();

View File

@ -27,6 +27,7 @@
using mindspore::kernel::AddressPtr;
using AddressPtrList = std::vector<mindspore::kernel::AddressPtr>;
using KernelGraph = mindspore::session::KernelGraph;
#endif
namespace mindspore {
namespace runtime {
@ -100,25 +101,16 @@ void LoadOutputs(const CNodePtr &cnode, const KernelLaunchInfo *launch_info_, ui
}
}
}
#endif
void DebugActor::Debug(const AnfNodePtr &node, const KernelLaunchInfo *launch_info_,
const DeviceContext *device_context, OpContext<DeviceTensor> *op_context, const AID *from_aid) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(device_context);
MS_EXCEPTION_IF_NULL(op_context);
MS_EXCEPTION_IF_NULL(from_aid);
// todo debug.
#ifdef ENABLE_GPU
if (node->isa<CNode>()) {
const auto &cnode = node->cast<CNodePtr>();
bool CheckReadData(const CNodePtr &cnode) {
auto debugger = Debugger::GetInstance();
if (debugger) {
std::string kernel_name = cnode->fullname_with_scope();
debugger->SetCurNode(kernel_name);
if (!debugger) {
return false;
}
bool read_data = false;
auto &dump_json_parser = DumpJsonParser::GetInstance();
bool dump_enabled = debugger->DumpDataEnabledIteration();
std::string kernel_name = cnode->fullname_with_scope();
if (dump_enabled) {
auto dump_mode = dump_json_parser.dump_mode();
// dump the node if dump_mode is 0, which means all kernels, or if this kernel is in the kernels list
@ -128,17 +120,59 @@ void DebugActor::Debug(const AnfNodePtr &node, const KernelLaunchInfo *launch_in
} else if (debugger->debugger_enabled()) {
read_data = debugger->ReadNodeDataRequired(cnode);
}
if (read_data) {
return read_data;
}
void ReadDataAndDump(const CNodePtr &cnode, const KernelLaunchInfo *launch_info_, uint32_t exec_order_) {
auto debugger = Debugger::GetInstance();
if (!debugger) {
return;
}
auto &dump_json_parser = DumpJsonParser::GetInstance();
bool dump_enabled = debugger->DumpDataEnabledIteration();
if (debugger->debugger_enabled() || dump_json_parser.InputNeedDump()) {
LoadInputs(cnode, launch_info_, exec_order_);
}
if (debugger->debugger_enabled() || dump_json_parser.OutputNeedDump()) {
LoadOutputs(cnode, launch_info_, exec_order_);
}
// Dump kernel
if (dump_enabled) {
auto kernel_graph = std::dynamic_pointer_cast<KernelGraph>(cnode->func_graph());
MS_EXCEPTION_IF_NULL(kernel_graph);
auto graph_id = kernel_graph->graph_id();
debugger->DumpSingleNode(cnode, graph_id);
// Clear Dumped data when online debugger is not enabled
if (!debugger->debugger_enabled()) {
debugger->ClearCurrentData();
}
}
// check if the node is last kernel
bool last_kernel = !AnfAlgo::IsInplaceNode(cnode, "skip");
debugger->PostExecuteNode(cnode, last_kernel);
}
#endif
void DebugActor::Debug(const AnfNodePtr &node, const KernelLaunchInfo *launch_info_,
const DeviceContext *device_context, OpContext<DeviceTensor> *op_context, const AID *from_aid) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(device_context);
MS_EXCEPTION_IF_NULL(op_context);
MS_EXCEPTION_IF_NULL(from_aid);
// todo debug.
MS_LOG(INFO) << "DebugActor is called";
#ifdef ENABLE_GPU
if (node->isa<CNode>()) {
const auto &cnode = node->cast<CNodePtr>();
auto debugger = Debugger::GetInstance();
if (debugger) {
std::string kernel_name = cnode->fullname_with_scope();
MS_LOG(INFO) << "kernel_name is " << kernel_name;
debugger->SetCurNode(kernel_name);
bool read_data = CheckReadData(cnode);
if (read_data) {
ReadDataAndDump(cnode, launch_info_, exec_order_);
}
}
exec_order_ += 1;
}
@ -151,13 +185,14 @@ void DebugActor::DebugOnStepEnd(OpContext<DeviceTensor> *op_context, const AID *
MS_EXCEPTION_IF_NULL(op_context);
MS_EXCEPTION_IF_NULL(from_aid);
// todo debug.
MS_LOG(INFO) << "DebugActor::DebugOnStepEnd is called";
#ifdef ENABLE_GPU
auto debugger = Debugger::GetInstance();
if (debugger) {
debugger->Debugger::UpdateStepNumGPU();
debugger->Debugger::LoadParametersAndConst();
// Reset exec_order for the next step
exec_order_ = 0;
debugger->Debugger::PostExecuteGraphDebugger();
}
#endif
// Call back to the from actor to process after debug finished.

View File

@ -304,16 +304,15 @@ GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph, const Devic
}
graph->set_is_all_nop_node(opt::IsAllNopNode(graph.get()));
#ifdef ENABLE_DEBUGGER
auto debugger = Debugger::GetInstance();
debugger->DumpInGraphCompiler(graph);
#endif
MS_EXCEPTION_IF_NULL(session_);
session_->InitAllBucket(graph, device_context);
session_->SetSummaryNodes(graph.get());
SetSummaryNodesRefCount(graph.get());
#ifdef ENABLE_DEBUGGER
auto debugger = Debugger::GetInstance();
debugger->DumpInGraphCompiler(graph);
if (debugger && debugger->DebuggerBackendEnabled()) {
debugger->LoadGraphs(graph);
}

View File

@ -606,12 +606,6 @@ void MindRTBackend::RunGraph(const ActorInfo &actor_info, const VectorRef &args,
if (!runtime::GraphScheduler::GetInstance().Run(actor_set)) {
MS_LOG(EXCEPTION) << "The actor runs failed, actor name: " << actor_set->name_;
}
// Debugger post-execute graph.
#ifdef ENABLE_DEBUGGER
if (Debugger::GetInstance()->DebuggerBackendEnabled()) {
Debugger::GetInstance()->PostExecuteGraphDebugger(graph_compiler_info.graphs_);
}
#endif
// Sync device stream.
const auto &first_device_context = graph_compiler_info.device_contexts_[0];