Fix ub fusion's circle

This commit is contained in:
huanghui 2021-02-24 11:12:58 +08:00
parent 095d7fb877
commit 04fee8a75c
2 changed files with 22 additions and 21 deletions

View File

@ -27,6 +27,7 @@
#include "base/core_ops.h"
#include "runtime/device/kernel_info.h"
#include "utils/ms_context.h"
#include "backend/optimizer/common/helper.h"
namespace mindspore {
namespace opt {
@ -353,6 +354,24 @@ void SetFusionOpRefInfos(session::KernelGraph *kernel_graph, const std::vector<A
}
}
}
void RemoveCircle(const session::KernelGraph &kernel_graph,
std::unordered_map<int64_t, BufferFusionInfo_t> *buffer_fusion_infos) {
MS_EXCEPTION_IF_NULL(buffer_fusion_infos);
for (auto &[fusion_id, fusion_info] : *buffer_fusion_infos) {
bool has_circle = false;
for (auto &inp : fusion_info.inputs_list) {
if (IsDepend(kernel_graph, inp, fusion_info.anf_nodes)) {
has_circle = true;
break;
}
}
if (has_circle) {
buffer_fusion_infos->erase(fusion_id);
}
}
}
} // namespace
void UbPatternFusion::GetBufferFusionInfo(session::KernelGraph *kernel_graph,
@ -361,6 +380,9 @@ void UbPatternFusion::GetBufferFusionInfo(session::KernelGraph *kernel_graph,
GetFusionScopeComputeNodeList(kernel_graph, buffer_fusion_infos);
GetFusionScopeInputNodeList(*kernel_graph, buffer_fusion_infos);
GetFusionScopeOutputNodeList(kernel_graph, buffer_fusion_infos);
// Remove circle which will produce a circle if do fusion
RemoveCircle(*kernel_graph, buffer_fusion_infos);
for (auto &buffer_fusion_info : *buffer_fusion_infos) {
buffer_fusion_info.second.kernel_build_info =
CreateFusionOpKernelInfo(buffer_fusion_info.second.inputs_list, buffer_fusion_info.second.outputs_list);

View File

@ -49,23 +49,6 @@ std::vector<int64_t> Convert2Long(const std::vector<size_t> &v) {
bool IsDepend(const FuncGraph &graph, const AnfNodePtr &node, const std::vector<AnfNodePtr> &nodes) {
MS_EXCEPTION_IF_NULL(node);
std::vector<AnfNodePtr> node_list = TopoSort(graph.get_return());
std::map<AnfNodePtr, std::set<AnfNodePtr>> control_depend_map;
for (auto &nd : node_list) {
MS_EXCEPTION_IF_NULL(nd);
if (AnfAlgo::CheckPrimitiveType(nd, prim::kPrimControlDepend)) {
auto control_depend = nd->cast<CNodePtr>();
auto prior_node = control_depend->input(kControlDependPriorIndex);
auto behind_node = control_depend->input(kControlDependBehindIndex);
auto it = control_depend_map.find(behind_node);
if (it == control_depend_map.end()) {
control_depend_map[behind_node] = std::set<AnfNodePtr>{prior_node};
} else {
it->second.insert(prior_node);
}
}
}
FuncGraphManagerPtr manager = graph.manager();
MS_EXCEPTION_IF_NULL(manager);
@ -88,10 +71,6 @@ bool IsDepend(const FuncGraph &graph, const AnfNodePtr &node, const std::vector<
auto inputs = cnode->inputs();
(void)todo.insert(todo.end(), inputs.begin(), inputs.end());
}
auto it = control_depend_map.find(nd);
if (it != control_depend_map.end()) {
(void)todo.insert(todo.end(), it->second.begin(), it->second.end());
}
}
return false;
}