!15179 update FindPrimalJPair in dunctor.cc

From: @huangbingjian
Reviewed-by: @ginfung,@zh_qh
Signed-off-by: @zh_qh
This commit is contained in:
mindspore-ci-bot 2021-04-15 20:48:01 +08:00 committed by Gitee
commit b065554e92
1 changed files with 60 additions and 34 deletions

View File

@ -819,11 +819,58 @@ bool DFunctor::AllReferencesStopped(const CNodePtr &node) {
return true;
}
CNodePtr GetJUser(const NodeUsersMap &node_user_map, const CNodePtr &cnode, int index) {
auto it = node_user_map.find(cnode);
if (it == node_user_map.end()) {
MS_LOG(EXCEPTION) << "J CNode not used {" << cnode->DebugString(2) << "/" << index << "}";
}
auto &j_users = it->second;
auto size = j_users.size();
if (size != 1) {
MS_LOG(EXCEPTION) << "Wrong J CNode use size " << size << " {" << cnode->DebugString(2) << "/" << index << "}";
}
return j_users.begin()->first->cast<CNodePtr>();
}
CNodePtr GetPrimalUser(const CNodePtr &j_user, const std::map<FuncGraphPtr, std::vector<CNodePtr>> &primal_map) {
// Check if J operation has relevant primal call in the same graph.
auto graph = j_user->func_graph();
auto iter = primal_map.find(graph);
if (iter == primal_map.end()) {
MS_LOG(WARNING) << "J operation has no relevant primal call in the same graph. Func graph: " << graph->ToString()
<< ", J user: " << j_user->DebugString();
return nullptr;
}
// Check if there is only one primal call corresponding to the specified j user.
auto primal_users = iter->second;
if (primal_users.size() != 1) {
MS_LOG(WARNING) << "It is recommended to call the forward network only once.";
MS_LOG(INFO) << "There is " << primal_users.size()
<< " primal calls for same J operation in the same graph. Func graph: " << graph->ToString()
<< ", J operation: " << j_user->DebugString() << ", Primal call: ";
size_t count = 0;
for (const auto &user : primal_users) {
MS_LOG(INFO) << "[ " << ++count << " ] : " << user->DebugString(2) << ", trace: " << trace::DumpSourceLines(user);
}
return nullptr;
}
// Check input size.
auto primal_user = primal_users[0];
if (primal_user->size() != j_user->size()) {
MS_LOG(WARNING) << "Input size incorrect, the input size of primal " << primal_user->DebugString() << " is "
<< primal_user->size() << ", and J user " << j_user->DebugString() << " is " << j_user->size();
return nullptr;
}
return primal_user;
}
static std::vector<std::pair<CNodePtr, CNodePtr>> FindPrimalJPair(const FuncGraphManagerPtr &manager,
const FuncGraphPtr &primal_graph) {
std::vector<std::pair<CNodePtr, CNodePtr>> primal_j_pair;
std::map<FuncGraphPtr, CNodePtr> primal_users_map;
auto &node_user_map = manager->node_users();
std::map<FuncGraphPtr, std::vector<CNodePtr>> primal_map;
const auto &node_user_map = manager->node_users();
// Search primal graph user cnodes.
for (auto &entry : primal_graph->func_graph_cnodes_index()) {
auto cnode = entry.first->first->cast<CNodePtr>();
@ -832,47 +879,26 @@ static std::vector<std::pair<CNodePtr, CNodePtr>> FindPrimalJPair(const FuncGrap
// To find real calling.
auto fg = cnode->func_graph();
MS_EXCEPTION_IF_NULL(fg);
if (primal_users_map.find(fg) != primal_users_map.end()) {
MS_LOG(EXCEPTION) << "The forward network is only allowed to be called once. Func graph: " << fg->ToString()
<< ", cnode: " << cnode->DebugString() << ", trace: " << trace::DumpSourceLines(cnode);
auto iter = primal_map.find(fg);
if (iter != primal_map.end()) {
iter->second.push_back(cnode);
continue;
}
primal_users_map[fg] = cnode;
primal_map[fg] = {cnode};
} else if (IsPrimitive(cnode->inputs().at(0), prim::kPrimJ)) {
// To find J user.
auto it = node_user_map.find(cnode);
if (it == node_user_map.end()) {
MS_LOG(EXCEPTION) << "J CNode not used {" << cnode->DebugString(2) << "/" << index << "}";
}
auto &j_users = it->second;
auto size = j_users.size();
if (size != 1) {
MS_LOG(EXCEPTION) << "Wrong J CNode use size " << size << " {" << cnode->DebugString(2) << "/" << index << "}";
}
auto j_user = j_users.begin()->first->cast<CNodePtr>();
auto j_user = GetJUser(node_user_map, cnode, index);
primal_j_pair.push_back({nullptr, j_user});
}
}
for (auto &[primal_user, j_user] : primal_j_pair) {
// Check if J operation has relevant primal call in the same graph
auto graph = j_user->func_graph();
auto iter = primal_users_map.find(graph);
if (iter == primal_users_map.end()) {
MS_LOG(WARNING) << "J operation has no relevant primal call in the same graph. Func graph: " << graph->ToString()
<< ", J user: " << j_user->DebugString();
continue;
auto primal = GetPrimalUser(j_user, primal_map);
if (primal != nullptr) {
MS_LOG(DEBUG) << "Primal_J pair is found, where primal is: " << primal->DebugString()
<< " and J user is: " << j_user->DebugString();
primal_user = primal;
}
// Check input size.
auto primal = iter->second;
if (primal->size() != j_user->size()) {
MS_LOG(WARNING) << "Input size incorrect, the input size of primal " << primal->DebugString() << " is "
<< primal->size() << ", and J user " << j_user->DebugString() << " is " << j_user->size();
continue;
}
primal_user = primal;
MS_LOG(DEBUG) << "Primal_J pair is found, where primal is: " << primal->DebugString()
<< " and J user is: " << j_user->DebugString();
}
return primal_j_pair;
}