Assign atomic_clean to different subgraphs.

This commit is contained in:
linqingke 2021-07-13 20:12:25 +08:00
parent d1f3434f4d
commit a90f9a98c5
1 changed files with 30 additions and 8 deletions

View File

@ -163,10 +163,11 @@ static void AddTbeClearZeroNode(mindspore::session::KernelGraph *const kernel_gr
}
static void AddFusionTbeClearZeroNode(mindspore::session::KernelGraph *const kernel_graph,
const mindspore::CNodePtr &stream_node,
const mindspore::CNodePtr &first_clear_node,
const std::vector<AnfNodePtr> &fusion_clear_inputs,
const std::vector<size_t> &clean_size_list,
std::vector<mindspore::CNodePtr> *new_nodes) {
MS_EXCEPTION_IF_NULL(first_clear_node);
auto clear_zero_prim = std::make_shared<Primitive>(kAtomicAddrCleanOpName);
MS_EXCEPTION_IF_NULL(clear_zero_prim);
auto new_value_node = NewValueNode(clear_zero_prim);
@ -182,8 +183,13 @@ static void AddFusionTbeClearZeroNode(mindspore::session::KernelGraph *const ker
builder->SetKernelType(KernelType::TBE_KERNEL);
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), clear_zero.get());
AnfAlgo::SetNodeAttr(kAttrAtomicAddMemSize, MakeValue(clean_size_list), clear_zero);
AnfAlgo::SetStreamDistinctionLabel(AnfAlgo::GetStreamDistinctionLabel(stream_node.get()), clear_zero.get());
AnfAlgo::SetStreamDistinctionLabel(AnfAlgo::GetStreamDistinctionLabel(first_clear_node.get()), clear_zero.get());
auto it = std::find(new_nodes->begin(), new_nodes->end(), first_clear_node);
if (it != new_nodes->end()) {
new_nodes->insert(it, clear_zero);
} else {
new_nodes->insert(new_nodes->begin(), clear_zero);
}
}
static bool IsAtomicNode(const CNodePtr &kernel_node) {
@ -299,10 +305,24 @@ std::map<AnfNodePtr, std::vector<size_t>> GetCommunicationOpInputInfo(
return comm_input_info_map;
}
// Reports whether the clean sizes gathered so far must be split off from current_node.
//
// clean_total_num: the size count that would result if current_node's sizes were appended.
// first_node:      the first node of the group being accumulated; may be nullptr.
// current_node:    the node about to be appended to the group; may be nullptr.
//
// Returns false when either node is nullptr. Otherwise returns true when
// clean_total_num is at least kMaxAttrMemListSize, or when the two nodes report
// different graph ids through AnfAlgo::GetGraphId.
bool IsNeedClearZeroNodeFusion(const size_t clean_total_num, const mindspore::CNodePtr &first_node,
                               const mindspore::CNodePtr &current_node) {
  const bool has_both_nodes = (first_node != nullptr) && (current_node != nullptr);
  if (!has_both_nodes) {
    return false;
  }
  // Both graph ids are looked up before any comparison, in first-then-current order.
  const auto graph_id_of_first = AnfAlgo::GetGraphId(first_node.get());
  const auto graph_id_of_current = AnfAlgo::GetGraphId(current_node.get());
  const bool size_limit_reached = clean_total_num >= kMaxAttrMemListSize;
  const bool graph_changed = graph_id_of_first != graph_id_of_current;
  return size_limit_reached || graph_changed;
}
static void TbeClearZeroNodeFusion(mindspore::session::KernelGraph *const kernel_graph) {
std::vector<CNodePtr> new_nodes;
std::vector<size_t> clean_size_list;
std::vector<AnfNodePtr> fusion_clear_inputs;
CNodePtr first_node = nullptr;
std::map<AnfNodePtr, std::vector<size_t>> comm_input_info_map = GetCommunicationOpInputInfo(kernel_graph);
for (const auto &anf_node : kernel_graph->execution_order()) {
std::string apply_function_name = AnfAlgo::GetCNodeName(anf_node);
@ -339,12 +359,15 @@ static void TbeClearZeroNodeFusion(mindspore::session::KernelGraph *const kernel
auto clean_sizes = CalCleanZerosSize(anf_node);
if (!clean_sizes.empty()) {
auto clean_total_num = clean_size_list.size() + clean_sizes.size();
if (clean_total_num >= kMaxAttrMemListSize) {
if (IsNeedClearZeroNodeFusion(clean_total_num, first_node, anf_node)) {
// create clean node
auto stream_node = new_nodes.empty() ? anf_node : new_nodes.front();
AddFusionTbeClearZeroNode(kernel_graph, stream_node, fusion_clear_inputs, clean_size_list, &new_nodes);
AddFusionTbeClearZeroNode(kernel_graph, first_node, fusion_clear_inputs, clean_size_list, &new_nodes);
clean_size_list.clear();
fusion_clear_inputs.clear();
first_node = nullptr;
}
if (fusion_clear_inputs.empty()) {
first_node = anf_node;
}
clean_size_list.insert(clean_size_list.end(), clean_sizes.begin(), clean_sizes.end());
fusion_clear_inputs.emplace_back(anf_node);
@ -357,8 +380,7 @@ static void TbeClearZeroNodeFusion(mindspore::session::KernelGraph *const kernel
if (!fusion_clear_inputs.empty() && !clean_size_list.empty()) {
// create clean node
auto stream_node = new_nodes.front();
AddFusionTbeClearZeroNode(kernel_graph, stream_node, fusion_clear_inputs, clean_size_list, &new_nodes);
AddFusionTbeClearZeroNode(kernel_graph, first_node, fusion_clear_inputs, clean_size_list, &new_nodes);
}
kernel_graph->set_execution_order(new_nodes);
}