forked from OSSInnovation/mindspore
!804 Fix ConfusionMulGrad op fusion failed in a testcase
Merge pull request !804 from huanghui/fix-confusionmulgrad-fusion-pass
This commit is contained in:
commit
22ba991fec
|
@ -51,6 +51,7 @@
|
|||
#include "pre_activate/ascend/ir_fusion/derelu_fusion.h"
|
||||
#include "pre_activate/ascend/ir_fusion/batchnorm_to_bninfer.h"
|
||||
#include "pre_activate/ascend/ir_fusion/batchnormgrad_to_bninfergrad.h"
|
||||
#include "pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion.h"
|
||||
#include "pre_activate/ascend/format_type/insert_trans_op.h"
|
||||
#include "pre_activate/pass/getitem_tuple.h"
|
||||
#include "pre_activate/pass/optimize_dependence.h"
|
||||
|
@ -104,6 +105,7 @@ void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) {
|
|||
ir_fusion_pm->AddPass(std::make_shared<MatmulBiasaddFusion>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<AddnFission>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<DereluFusion>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<ConfusionMulGradFusion>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<TransposeTransDataFusion>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<GetitemTuple>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<BatchNorm2BNInfer>());
|
||||
|
|
|
@ -73,13 +73,16 @@ AnfNodePtr GetMul0(const FuncGraphPtr &graph, const AnfNodePtr &input2, const An
|
|||
return mul0;
|
||||
}
|
||||
|
||||
bool QuitFusion(const FuncGraphPtr &graph, const AnfNodePtr &mul0_anf, const AnfNodePtr &reduce_sum) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
bool QuitFusion(const FuncGraphPtr &graph, const AnfNodePtr &mul0_anf, const AnfNodePtr &mul1_anf,
|
||||
const AnfNodePtr &reduce_sum) {
|
||||
MS_EXCEPTION_IF_NULL(mul0_anf);
|
||||
MS_EXCEPTION_IF_NULL(mul1_anf);
|
||||
MS_EXCEPTION_IF_NULL(reduce_sum);
|
||||
if (!mul0_anf->isa<CNode>()) {
|
||||
if (!mul0_anf->isa<CNode>() || !mul1_anf->isa<CNode>()) {
|
||||
return true;
|
||||
}
|
||||
auto mul1 = mul1_anf->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(mul1);
|
||||
auto mul0 = mul0_anf->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(mul0);
|
||||
|
||||
|
@ -88,20 +91,14 @@ bool QuitFusion(const FuncGraphPtr &graph, const AnfNodePtr &mul0_anf, const Anf
|
|||
return true;
|
||||
}
|
||||
|
||||
auto manager = graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
if (manager->node_users().find(reduce_sum) == manager->node_users().end()) {
|
||||
MS_LOG(EXCEPTION) << "node has no output in manager";
|
||||
}
|
||||
const AnfNodeIndexSet &outputs_set = manager->node_users()[reduce_sum];
|
||||
auto it = std::find_if(outputs_set.begin(), outputs_set.end(), [&mul0](const std::pair<AnfNodePtr, int> &node_index) {
|
||||
return node_index.first == mul0->input(1) || node_index.first == mul0;
|
||||
});
|
||||
if (it != outputs_set.end()) {
|
||||
MS_LOG(INFO) << "ReduceSum's output node is mul0's input or mul0! If do fusion, graph will exist a circle";
|
||||
if (IsDepend(graph, mul0->input(1), reduce_sum)) {
|
||||
MS_LOG(INFO) << "mul0->input(1) depends on reduce_sum, quit fusion";
|
||||
return true;
|
||||
}
|
||||
if (IsDepend(graph, mul1->input(1), mul0)) {
|
||||
MS_LOG(INFO) << "mul1->input(1) depends on mul0, quit fusion";
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
} // namespace
|
||||
|
@ -131,7 +128,7 @@ const AnfNodePtr ConfusionMulGradFusion::Process(const FuncGraphPtr &graph, cons
|
|||
MS_LOG(INFO) << "Mul0 do not exist, quit fusion";
|
||||
return nullptr;
|
||||
}
|
||||
if (QuitFusion(graph, mul0, node)) {
|
||||
if (QuitFusion(graph, mul0, mul1, node)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
|
|
@ -18,6 +18,9 @@
|
|||
#include <string>
|
||||
#include <unordered_set>
|
||||
#include <algorithm>
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <deque>
|
||||
#include "utils/utils.h"
|
||||
#include "utils/base_ref.h"
|
||||
#include "session/anf_runtime_algorithm.h"
|
||||
|
@ -35,6 +38,56 @@ std::vector<int> Convert2Int(const std::vector<size_t> &v) {
|
|||
return result;
|
||||
}
|
||||
|
||||
bool IsDepend(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(node1);
|
||||
MS_EXCEPTION_IF_NULL(node2);
|
||||
std::vector<AnfNodePtr> node_list = TopoSort(graph->get_return());
|
||||
std::map<AnfNodePtr, std::set<AnfNodePtr>> control_depend_map;
|
||||
for (auto &nd : node_list) {
|
||||
if (AnfAlgo::CheckPrimitiveType(nd, prim::kPrimControlDepend)) {
|
||||
auto control_depend = nd->cast<CNodePtr>();
|
||||
auto prior_node = control_depend->input(kControlDependPriorIndex);
|
||||
auto behind_node = control_depend->input(kControlDependBehindIndex);
|
||||
auto it = control_depend_map.find(behind_node);
|
||||
if (it == control_depend_map.end()) {
|
||||
control_depend_map[behind_node] = std::set<AnfNodePtr>{prior_node};
|
||||
} else {
|
||||
it->second.insert(prior_node);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
FuncGraphManagerPtr manager = graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
|
||||
std::unordered_set<AnfNodePtr> seen_node;
|
||||
std::deque<AnfNodePtr> todo{node1};
|
||||
while (!todo.empty()) {
|
||||
AnfNodePtr node = todo.front();
|
||||
todo.pop_front();
|
||||
if (seen_node.count(node) > 0 || !manager->all_nodes().contains(node)) {
|
||||
continue;
|
||||
}
|
||||
(void)seen_node.insert(node);
|
||||
|
||||
if (node == node2) {
|
||||
return true;
|
||||
}
|
||||
if (node->isa<CNode>()) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto inputs = cnode->inputs();
|
||||
(void)todo.insert(todo.end(), inputs.begin(), inputs.end());
|
||||
}
|
||||
auto it = control_depend_map.find(node);
|
||||
if (it != control_depend_map.end()) {
|
||||
(void)todo.insert(todo.end(), it->second.begin(), it->second.end());
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool UnVisited(const BaseRef &n) {
|
||||
if (utils::isa<AnfNodePtr>(n)) {
|
||||
AnfNodePtr in = utils::cast<AnfNodePtr>(n);
|
||||
|
|
|
@ -111,6 +111,9 @@ enum ConvBn1Output {
|
|||
|
||||
std::vector<int> Convert2Int(const std::vector<size_t> &v);
|
||||
|
||||
// check whether node1 depends on node2 or not
|
||||
bool IsDepend(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2);
|
||||
|
||||
bool UnVisited(const BaseRef &n);
|
||||
|
||||
bool Visited(const BaseRef &n);
|
||||
|
|
Loading…
Reference in New Issue