Assign atomic_clean to different subgraphs.
This commit is contained in:
parent
d1f3434f4d
commit
a90f9a98c5
|
@ -163,10 +163,11 @@ static void AddTbeClearZeroNode(mindspore::session::KernelGraph *const kernel_gr
|
|||
}
|
||||
|
||||
static void AddFusionTbeClearZeroNode(mindspore::session::KernelGraph *const kernel_graph,
|
||||
const mindspore::CNodePtr &stream_node,
|
||||
const mindspore::CNodePtr &first_clear_node,
|
||||
const std::vector<AnfNodePtr> &fusion_clear_inputs,
|
||||
const std::vector<size_t> &clean_size_list,
|
||||
std::vector<mindspore::CNodePtr> *new_nodes) {
|
||||
MS_EXCEPTION_IF_NULL(first_clear_node);
|
||||
auto clear_zero_prim = std::make_shared<Primitive>(kAtomicAddrCleanOpName);
|
||||
MS_EXCEPTION_IF_NULL(clear_zero_prim);
|
||||
auto new_value_node = NewValueNode(clear_zero_prim);
|
||||
|
@ -182,8 +183,13 @@ static void AddFusionTbeClearZeroNode(mindspore::session::KernelGraph *const ker
|
|||
builder->SetKernelType(KernelType::TBE_KERNEL);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), clear_zero.get());
|
||||
AnfAlgo::SetNodeAttr(kAttrAtomicAddMemSize, MakeValue(clean_size_list), clear_zero);
|
||||
AnfAlgo::SetStreamDistinctionLabel(AnfAlgo::GetStreamDistinctionLabel(stream_node.get()), clear_zero.get());
|
||||
new_nodes->insert(new_nodes->begin(), clear_zero);
|
||||
AnfAlgo::SetStreamDistinctionLabel(AnfAlgo::GetStreamDistinctionLabel(first_clear_node.get()), clear_zero.get());
|
||||
auto it = std::find(new_nodes->begin(), new_nodes->end(), first_clear_node);
|
||||
if (it != new_nodes->end()) {
|
||||
new_nodes->insert(it, clear_zero);
|
||||
} else {
|
||||
new_nodes->insert(new_nodes->begin(), clear_zero);
|
||||
}
|
||||
}
|
||||
|
||||
static bool IsAtomicNode(const CNodePtr &kernel_node) {
|
||||
|
@ -299,10 +305,24 @@ std::map<AnfNodePtr, std::vector<size_t>> GetCommunicationOpInputInfo(
|
|||
return comm_input_info_map;
|
||||
}
|
||||
|
||||
// Tells the caller whether a fused atomic-clean node has to be emitted before `current_node`
// is added to the batch of nodes that started at `first_node`.
//
// Parameters:
//   clean_total_num - number of clean sizes the batch would hold once `current_node` is added.
//   first_node      - first node of the current batch; may be nullptr when no batch is open.
//   current_node    - node that is about to be added to the batch.
//
// Returns:
//   false if either node pointer is null. Otherwise true when `clean_total_num` has reached
//   kMaxAttrMemListSize, or when the two nodes report different graph ids (so one clean node is
//   never shared across subgraphs); false in every other case.
bool IsNeedClearZeroNodeFusion(const size_t clean_total_num, const mindspore::CNodePtr &first_node,
                               const mindspore::CNodePtr &current_node) {
  // Without both nodes there is nothing to compare, so no fusion is requested.
  if (first_node == nullptr || current_node == nullptr) {
    return false;
  }
  const auto first_graph_id = AnfAlgo::GetGraphId(first_node.get());
  const auto current_graph_id = AnfAlgo::GetGraphId(current_node.get());
  // Either the size list is full, or the batch would cross a subgraph boundary.
  return clean_total_num >= kMaxAttrMemListSize || first_graph_id != current_graph_id;
}
|
||||
|
||||
static void TbeClearZeroNodeFusion(mindspore::session::KernelGraph *const kernel_graph) {
|
||||
std::vector<CNodePtr> new_nodes;
|
||||
std::vector<size_t> clean_size_list;
|
||||
std::vector<AnfNodePtr> fusion_clear_inputs;
|
||||
CNodePtr first_node = nullptr;
|
||||
std::map<AnfNodePtr, std::vector<size_t>> comm_input_info_map = GetCommunicationOpInputInfo(kernel_graph);
|
||||
for (const auto &anf_node : kernel_graph->execution_order()) {
|
||||
std::string apply_function_name = AnfAlgo::GetCNodeName(anf_node);
|
||||
|
@ -339,12 +359,15 @@ static void TbeClearZeroNodeFusion(mindspore::session::KernelGraph *const kernel
|
|||
auto clean_sizes = CalCleanZerosSize(anf_node);
|
||||
if (!clean_sizes.empty()) {
|
||||
auto clean_total_num = clean_size_list.size() + clean_sizes.size();
|
||||
if (clean_total_num >= kMaxAttrMemListSize) {
|
||||
if (IsNeedClearZeroNodeFusion(clean_total_num, first_node, anf_node)) {
|
||||
// create clean node
|
||||
auto stream_node = new_nodes.empty() ? anf_node : new_nodes.front();
|
||||
AddFusionTbeClearZeroNode(kernel_graph, stream_node, fusion_clear_inputs, clean_size_list, &new_nodes);
|
||||
AddFusionTbeClearZeroNode(kernel_graph, first_node, fusion_clear_inputs, clean_size_list, &new_nodes);
|
||||
clean_size_list.clear();
|
||||
fusion_clear_inputs.clear();
|
||||
first_node = nullptr;
|
||||
}
|
||||
if (fusion_clear_inputs.empty()) {
|
||||
first_node = anf_node;
|
||||
}
|
||||
clean_size_list.insert(clean_size_list.end(), clean_sizes.begin(), clean_sizes.end());
|
||||
fusion_clear_inputs.emplace_back(anf_node);
|
||||
|
@ -357,8 +380,7 @@ static void TbeClearZeroNodeFusion(mindspore::session::KernelGraph *const kernel
|
|||
|
||||
if (!fusion_clear_inputs.empty() && !clean_size_list.empty()) {
|
||||
// create clean node
|
||||
auto stream_node = new_nodes.front();
|
||||
AddFusionTbeClearZeroNode(kernel_graph, stream_node, fusion_clear_inputs, clean_size_list, &new_nodes);
|
||||
AddFusionTbeClearZeroNode(kernel_graph, first_node, fusion_clear_inputs, clean_size_list, &new_nodes);
|
||||
}
|
||||
kernel_graph->set_execution_order(new_nodes);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue