forked from mindspore-Ecosystem/mindspore
!15179 update FindPrimalJPair in dunctor.cc
From: @huangbingjian Reviewed-by: @ginfung,@zh_qh Signed-off-by: @zh_qh
This commit is contained in:
commit
b065554e92
|
@ -819,11 +819,58 @@ bool DFunctor::AllReferencesStopped(const CNodePtr &node) {
|
|||
return true;
|
||||
}
|
||||
|
||||
CNodePtr GetJUser(const NodeUsersMap &node_user_map, const CNodePtr &cnode, int index) {
|
||||
auto it = node_user_map.find(cnode);
|
||||
if (it == node_user_map.end()) {
|
||||
MS_LOG(EXCEPTION) << "J CNode not used {" << cnode->DebugString(2) << "/" << index << "}";
|
||||
}
|
||||
auto &j_users = it->second;
|
||||
auto size = j_users.size();
|
||||
if (size != 1) {
|
||||
MS_LOG(EXCEPTION) << "Wrong J CNode use size " << size << " {" << cnode->DebugString(2) << "/" << index << "}";
|
||||
}
|
||||
return j_users.begin()->first->cast<CNodePtr>();
|
||||
}
|
||||
|
||||
CNodePtr GetPrimalUser(const CNodePtr &j_user, const std::map<FuncGraphPtr, std::vector<CNodePtr>> &primal_map) {
|
||||
// Check if J operation has relevant primal call in the same graph.
|
||||
auto graph = j_user->func_graph();
|
||||
auto iter = primal_map.find(graph);
|
||||
if (iter == primal_map.end()) {
|
||||
MS_LOG(WARNING) << "J operation has no relevant primal call in the same graph. Func graph: " << graph->ToString()
|
||||
<< ", J user: " << j_user->DebugString();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Check if there is only one primal call corresponding to the specified j user.
|
||||
auto primal_users = iter->second;
|
||||
if (primal_users.size() != 1) {
|
||||
MS_LOG(WARNING) << "It is recommended to call the forward network only once.";
|
||||
MS_LOG(INFO) << "There is " << primal_users.size()
|
||||
<< " primal calls for same J operation in the same graph. Func graph: " << graph->ToString()
|
||||
<< ", J operation: " << j_user->DebugString() << ", Primal call: ";
|
||||
size_t count = 0;
|
||||
for (const auto &user : primal_users) {
|
||||
MS_LOG(INFO) << "[ " << ++count << " ] : " << user->DebugString(2) << ", trace: " << trace::DumpSourceLines(user);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Check input size.
|
||||
auto primal_user = primal_users[0];
|
||||
if (primal_user->size() != j_user->size()) {
|
||||
MS_LOG(WARNING) << "Input size incorrect, the input size of primal " << primal_user->DebugString() << " is "
|
||||
<< primal_user->size() << ", and J user " << j_user->DebugString() << " is " << j_user->size();
|
||||
return nullptr;
|
||||
}
|
||||
return primal_user;
|
||||
}
|
||||
|
||||
static std::vector<std::pair<CNodePtr, CNodePtr>> FindPrimalJPair(const FuncGraphManagerPtr &manager,
|
||||
const FuncGraphPtr &primal_graph) {
|
||||
std::vector<std::pair<CNodePtr, CNodePtr>> primal_j_pair;
|
||||
std::map<FuncGraphPtr, CNodePtr> primal_users_map;
|
||||
auto &node_user_map = manager->node_users();
|
||||
std::map<FuncGraphPtr, std::vector<CNodePtr>> primal_map;
|
||||
const auto &node_user_map = manager->node_users();
|
||||
// Search primal graph user cnodes.
|
||||
for (auto &entry : primal_graph->func_graph_cnodes_index()) {
|
||||
auto cnode = entry.first->first->cast<CNodePtr>();
|
||||
|
@ -832,47 +879,26 @@ static std::vector<std::pair<CNodePtr, CNodePtr>> FindPrimalJPair(const FuncGrap
|
|||
// To find real calling.
|
||||
auto fg = cnode->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
if (primal_users_map.find(fg) != primal_users_map.end()) {
|
||||
MS_LOG(EXCEPTION) << "The forward network is only allowed to be called once. Func graph: " << fg->ToString()
|
||||
<< ", cnode: " << cnode->DebugString() << ", trace: " << trace::DumpSourceLines(cnode);
|
||||
auto iter = primal_map.find(fg);
|
||||
if (iter != primal_map.end()) {
|
||||
iter->second.push_back(cnode);
|
||||
continue;
|
||||
}
|
||||
primal_users_map[fg] = cnode;
|
||||
primal_map[fg] = {cnode};
|
||||
} else if (IsPrimitive(cnode->inputs().at(0), prim::kPrimJ)) {
|
||||
// To find J user.
|
||||
auto it = node_user_map.find(cnode);
|
||||
if (it == node_user_map.end()) {
|
||||
MS_LOG(EXCEPTION) << "J CNode not used {" << cnode->DebugString(2) << "/" << index << "}";
|
||||
}
|
||||
auto &j_users = it->second;
|
||||
auto size = j_users.size();
|
||||
if (size != 1) {
|
||||
MS_LOG(EXCEPTION) << "Wrong J CNode use size " << size << " {" << cnode->DebugString(2) << "/" << index << "}";
|
||||
}
|
||||
auto j_user = j_users.begin()->first->cast<CNodePtr>();
|
||||
auto j_user = GetJUser(node_user_map, cnode, index);
|
||||
primal_j_pair.push_back({nullptr, j_user});
|
||||
}
|
||||
}
|
||||
|
||||
for (auto &[primal_user, j_user] : primal_j_pair) {
|
||||
// Check if J operation has relevant primal call in the same graph
|
||||
auto graph = j_user->func_graph();
|
||||
auto iter = primal_users_map.find(graph);
|
||||
if (iter == primal_users_map.end()) {
|
||||
MS_LOG(WARNING) << "J operation has no relevant primal call in the same graph. Func graph: " << graph->ToString()
|
||||
<< ", J user: " << j_user->DebugString();
|
||||
continue;
|
||||
auto primal = GetPrimalUser(j_user, primal_map);
|
||||
if (primal != nullptr) {
|
||||
MS_LOG(DEBUG) << "Primal_J pair is found, where primal is: " << primal->DebugString()
|
||||
<< " and J user is: " << j_user->DebugString();
|
||||
primal_user = primal;
|
||||
}
|
||||
// Check input size.
|
||||
auto primal = iter->second;
|
||||
if (primal->size() != j_user->size()) {
|
||||
MS_LOG(WARNING) << "Input size incorrect, the input size of primal " << primal->DebugString() << " is "
|
||||
<< primal->size() << ", and J user " << j_user->DebugString() << " is " << j_user->size();
|
||||
continue;
|
||||
}
|
||||
|
||||
primal_user = primal;
|
||||
MS_LOG(DEBUG) << "Primal_J pair is found, where primal is: " << primal->DebugString()
|
||||
<< " and J user is: " << j_user->DebugString();
|
||||
}
|
||||
return primal_j_pair;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue