fix bug of node users not update

This commit is contained in:
mengyuanli 2021-09-02 21:35:44 +08:00
parent a93031498c
commit 8032d159f5
2 changed files with 22 additions and 0 deletions

View File

@ -275,6 +275,21 @@ int MindIRControlFlowAdjust::InsertPartialFusionForRawCall(const std::set<FuncGr
return RET_OK; return RET_OK;
} }
int MindIRControlFlowAdjust::ResetFuncGraph(const FuncGraphPtr &fg, std::set<FuncGraphPtr> all_func_graphs) {
MS_CHECK_TRUE_MSG(fg != nullptr, RET_NULL_PTR, "fg is nullptr.");
auto manager = fg->manager();
MS_CHECK_TRUE_MSG(manager != nullptr, RET_NULL_PTR, "manager is nullptr.");
manager->Clear();
manager->AddFuncGraph(fg, true);
for (auto &item : all_func_graphs) {
if (item == fg) {
continue;
}
manager->AddFuncGraph(item);
}
return RET_OK;
}
bool MindIRControlFlowAdjust::Run(const FuncGraphPtr &func_graph) { bool MindIRControlFlowAdjust::Run(const FuncGraphPtr &func_graph) {
if (this->fmk_type_ != FmkType::kFmkTypeMs) { if (this->fmk_type_ != FmkType::kFmkTypeMs) {
MS_LOG(INFO) << "The framework type of model should be MindIR."; MS_LOG(INFO) << "The framework type of model should be MindIR.";
@ -297,6 +312,12 @@ bool MindIRControlFlowAdjust::Run(const FuncGraphPtr &func_graph) {
MS_LOG(ERROR) << "AddAfterFgForInlinedFg failed."; MS_LOG(ERROR) << "AddAfterFgForInlinedFg failed.";
return false; return false;
} }
ret = ResetFuncGraph(func_graph, all_func_graphs);
if (ret != RET_OK) {
MS_LOG(ERROR) << "ResetFuncGraph failed.";
return false;
}
if (status_ != RET_OK) { if (status_ != RET_OK) {
return false; return false;
} }

View File

@ -43,6 +43,7 @@ class MindIRControlFlowAdjust {
int AddAfterFgForInlinedFg(const std::set<FuncGraphPtr> &all_func_graphs, const FuncGraphPtr &main_fg); int AddAfterFgForInlinedFg(const std::set<FuncGraphPtr> &all_func_graphs, const FuncGraphPtr &main_fg);
int InsertPartialFusionForRawCall(const std::set<FuncGraphPtr> &all_func_graphs); int InsertPartialFusionForRawCall(const std::set<FuncGraphPtr> &all_func_graphs);
CNodePtr GetMainFgSwitchNode(const FuncGraphPtr &fg); CNodePtr GetMainFgSwitchNode(const FuncGraphPtr &fg);
int ResetFuncGraph(const FuncGraphPtr &fg, std::set<FuncGraphPtr> all_func_graphs);
private: private:
FmkType fmk_type_ = FmkType::kFmkTypeMs; FmkType fmk_type_ = FmkType::kFmkTypeMs;