!19252 add dump pb and nccl protected
Merge pull request !19252 from limingqi107/bug_fix3
This commit is contained in:
commit
d0bab79fde
|
@ -279,6 +279,9 @@ void DynamicMemPoolBestFit::ReleaseDeviceRes() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
global_mem_block_list_.clear();
|
||||||
|
global_idle_mem_buf_map_.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
void DynamicMemPoolBestFit::DumpDynamicMemPoolInfo() {
|
void DynamicMemPoolBestFit::DumpDynamicMemPoolInfo() {
|
||||||
|
|
|
@ -32,6 +32,10 @@ ncclUniqueId NCCLWrapper::nccl_unique_id() const {
|
||||||
}
|
}
|
||||||
|
|
||||||
void NCCLWrapper::InitNCCLComm() {
|
void NCCLWrapper::InitNCCLComm() {
|
||||||
|
if (comm_init_done_) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
for (auto group : group_info_) {
|
for (auto group : group_info_) {
|
||||||
std::string group_name = group.first;
|
std::string group_name = group.first;
|
||||||
NcclGroupInfo group_info = group.second;
|
NcclGroupInfo group_info = group.second;
|
||||||
|
|
|
@ -24,6 +24,7 @@
|
||||||
#include "ir/tensor.h"
|
#include "ir/tensor.h"
|
||||||
#include "backend/optimizer/common/helper.h"
|
#include "backend/optimizer/common/helper.h"
|
||||||
#include "base/base_ref_utils.h"
|
#include "base/base_ref_utils.h"
|
||||||
|
#include "debug/dump_proto.h"
|
||||||
#ifdef ENABLE_DEBUGGER
|
#ifdef ENABLE_DEBUGGER
|
||||||
#include "debug/debugger/debugger.h"
|
#include "debug/debugger/debugger.h"
|
||||||
#endif
|
#endif
|
||||||
|
@ -297,6 +298,14 @@ GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph, const Devic
|
||||||
auto &json_parser = DumpJsonParser::GetInstance();
|
auto &json_parser = DumpJsonParser::GetInstance();
|
||||||
json_parser.Parse();
|
json_parser.Parse();
|
||||||
|
|
||||||
|
const auto &ms_context = MsContext::GetInstance();
|
||||||
|
MS_EXCEPTION_IF_NULL(ms_context);
|
||||||
|
bool save_graphs = ms_context->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
|
||||||
|
// Dump .pb graph before graph optimization.
|
||||||
|
if (save_graphs) {
|
||||||
|
DumpIRProto(graph, "before_opt_" + std::to_string(graph->graph_id()));
|
||||||
|
}
|
||||||
|
|
||||||
// Execute optimization pass.
|
// Execute optimization pass.
|
||||||
auto outputs_before_optimizer = AnfAlgo::GetAllOutputWithIndex(graph->output());
|
auto outputs_before_optimizer = AnfAlgo::GetAllOutputWithIndex(graph->output());
|
||||||
device_context->OptimizeGraph(graph);
|
device_context->OptimizeGraph(graph);
|
||||||
|
@ -308,8 +317,6 @@ GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph, const Devic
|
||||||
// 'KernelMod' is real executive object of kernel.
|
// 'KernelMod' is real executive object of kernel.
|
||||||
device_context->CreateKernel(graph->execution_order());
|
device_context->CreateKernel(graph->execution_order());
|
||||||
|
|
||||||
const auto &ms_context = MsContext::GetInstance();
|
|
||||||
MS_EXCEPTION_IF_NULL(ms_context);
|
|
||||||
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
|
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
|
||||||
// Create device address for all anf nodes of graph.
|
// Create device address for all anf nodes of graph.
|
||||||
CreateDeviceAddress(graph, device_context);
|
CreateDeviceAddress(graph, device_context);
|
||||||
|
@ -322,6 +329,12 @@ GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph, const Devic
|
||||||
|
|
||||||
session_->SetSummaryNodes(graph.get());
|
session_->SetSummaryNodes(graph.get());
|
||||||
SetSummaryNodesRefCount(graph.get());
|
SetSummaryNodesRefCount(graph.get());
|
||||||
|
|
||||||
|
// Dump .pb graph after graph optimization.
|
||||||
|
if (save_graphs) {
|
||||||
|
DumpIRProto(graph, "after_opt_" + std::to_string(graph->graph_id()));
|
||||||
|
}
|
||||||
|
|
||||||
#ifdef ENABLE_DEBUGGER
|
#ifdef ENABLE_DEBUGGER
|
||||||
auto debugger = Debugger::GetInstance();
|
auto debugger = Debugger::GetInstance();
|
||||||
debugger->DumpInGraphCompiler(graph);
|
debugger->DumpInGraphCompiler(graph);
|
||||||
|
@ -329,6 +342,8 @@ GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph, const Devic
|
||||||
debugger->LoadGraphs(graph);
|
debugger->LoadGraphs(graph);
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
session_->DumpGraph(graph);
|
||||||
return graph->graph_id();
|
return graph->graph_id();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1962,22 +1962,10 @@ void GraphScheduler::LinkDeviceTensorStoreForAutoMonadActor(const std::vector<Ke
|
||||||
// Link from copy actor to kernel actor users.
|
// Link from copy actor to kernel actor users.
|
||||||
for (auto &output_contorl : kernel_actor->output_control_arrows_) {
|
for (auto &output_contorl : kernel_actor->output_control_arrows_) {
|
||||||
copy_actor->output_control_arrows_.emplace_back(output_contorl);
|
copy_actor->output_control_arrows_.emplace_back(output_contorl);
|
||||||
auto to_actor = FetchActor(output_contorl.Name());
|
|
||||||
MS_EXCEPTION_IF_NULL(to_actor);
|
|
||||||
if (output_contorl.Name().find("_LoopCountActor") != string::npos) {
|
|
||||||
auto real_to_actor = dynamic_cast<LoopCountActor *>(to_actor);
|
|
||||||
MS_EXCEPTION_IF_NULL(real_to_actor);
|
|
||||||
real_to_actor->input_controls_num_++;
|
|
||||||
} else if (output_contorl.Name().find("copy_from") != string::npos) {
|
|
||||||
auto real_to_actor = dynamic_cast<CopyActor *>(to_actor);
|
|
||||||
MS_EXCEPTION_IF_NULL(real_to_actor);
|
|
||||||
real_to_actor->input_controls_num_++;
|
|
||||||
} else {
|
|
||||||
auto real_to_actor = dynamic_cast<KernelActor *>(to_actor);
|
|
||||||
MS_EXCEPTION_IF_NULL(real_to_actor);
|
|
||||||
real_to_actor->input_controls_num_++;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
// Move the control arrows from kernel actor to kernel actor users.
|
||||||
|
kernel_actor->output_control_arrows_.clear();
|
||||||
|
|
||||||
// Link from kernel actor to copy actor.
|
// Link from kernel actor to copy actor.
|
||||||
kernel_actor->output_control_arrows_.emplace_back(copy_actor->GetAID());
|
kernel_actor->output_control_arrows_.emplace_back(copy_actor->GetAID());
|
||||||
copy_actor->input_controls_num_++;
|
copy_actor->input_controls_num_++;
|
||||||
|
|
Loading…
Reference in New Issue