forked from mindspore-Ecosystem/mindspore
fix confusionmulgrad fusion pass cannot work
This commit is contained in:
parent
69f6a1d6bd
commit
230e77f923
|
@ -72,6 +72,38 @@ AnfNodePtr GetMul0(const FuncGraphPtr &graph, const AnfNodePtr &input2, const An
|
|||
}
|
||||
return mul0;
|
||||
}
|
||||
|
||||
bool QuitFusion(const FuncGraphPtr &graph, const AnfNodePtr &mul0_anf, const AnfNodePtr &reduce_sum) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(mul0_anf);
|
||||
MS_EXCEPTION_IF_NULL(reduce_sum);
|
||||
if (!mul0_anf->isa<CNode>()) {
|
||||
return true;
|
||||
}
|
||||
auto mul0 = mul0_anf->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(mul0);
|
||||
|
||||
// when network is _VirtualDatasetCell, quit fusion
|
||||
if (mul0->fullname_with_scope().find("network-_VirtualDatasetCell") != std::string::npos) {
|
||||
return true;
|
||||
}
|
||||
|
||||
auto manager = graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
if (manager->node_users().find(reduce_sum) == manager->node_users().end()) {
|
||||
MS_LOG(EXCEPTION) << "node has no output in manager";
|
||||
}
|
||||
const AnfNodeIndexSet &outputs_set = manager->node_users()[reduce_sum];
|
||||
auto it = std::find_if(outputs_set.begin(), outputs_set.end(), [&mul0](const std::pair<AnfNodePtr, int> &node_index) {
|
||||
return node_index.first == mul0->input(1) || node_index.first == mul0;
|
||||
});
|
||||
if (it != outputs_set.end()) {
|
||||
MS_LOG(INFO) << "ReduceSum's output node is mul0's input or mul0! If do fusion, graph will exist a circle";
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
const BaseRef ConfusionMulGradFusion::DefinePattern() const {
|
||||
|
@ -90,9 +122,6 @@ const AnfNodePtr ConfusionMulGradFusion::Process(const FuncGraphPtr &graph, cons
|
|||
auto reduce_sum = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(reduce_sum);
|
||||
auto mul1 = reduce_sum->input(1);
|
||||
if (mul1->fullname_with_scope().find("bert/encoder") == std::string::npos) {
|
||||
return nullptr;
|
||||
}
|
||||
if (IsUsedByOthers(graph, mul1)) {
|
||||
MS_LOG(INFO) << "Mul1 is used by others, quit fusion!";
|
||||
return nullptr;
|
||||
|
@ -102,6 +131,9 @@ const AnfNodePtr ConfusionMulGradFusion::Process(const FuncGraphPtr &graph, cons
|
|||
MS_LOG(INFO) << "Mul0 do not exist, quit fusion";
|
||||
return nullptr;
|
||||
}
|
||||
if (QuitFusion(graph, mul0, node)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto fusion_node = CreateFusionNode(graph, reduce_sum, mul0, input3);
|
||||
std::vector<AnfNodePtr> fusion_node_outputs;
|
||||
|
|
|
@ -32,11 +32,6 @@ class TestHWOptimizeConfusionMulGradFusion : public BackendCommon {
|
|||
TEST_F(TestHWOptimizeConfusionMulGradFusion, test_fusion) {
|
||||
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_confusion_mul_grad_fusion", "before");
|
||||
EXPECT_NE(g, nullptr);
|
||||
auto bert_scope = std::make_shared<Scope>("bert/encoder");
|
||||
for (auto node : TopoSort(g->get_return())) {
|
||||
node->set_scope(bert_scope);
|
||||
}
|
||||
|
||||
std::vector<int> shp{1, 1, 1, 1};
|
||||
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
|
||||
AbstractBasePtrList args_spec_list;
|
||||
|
|
Loading…
Reference in New Issue