forked from mindspore-Ecosystem/mindspore
!14678 update findPrimalJPair
From: @huangbingjian Reviewed-by: @zh_qh Signed-off-by: @zh_qh
This commit is contained in:
commit
19d43c323a
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -822,10 +822,10 @@ bool DFunctor::AllReferencesStopped(const CNodePtr &node) {
|
|||
return true;
|
||||
}
|
||||
|
||||
static std::pair<CNodePtr, CNodePtr> FindPrimalJPair(const FuncGraphManagerPtr &manager,
|
||||
const FuncGraphPtr &primal_graph) {
|
||||
CNodePtr primal_user = nullptr;
|
||||
CNodePtr j_user = nullptr;
|
||||
static std::vector<std::pair<CNodePtr, CNodePtr>> FindPrimalJPair(const FuncGraphManagerPtr &manager,
|
||||
const FuncGraphPtr &primal_graph) {
|
||||
std::vector<std::pair<CNodePtr, CNodePtr>> primal_j_pair;
|
||||
std::map<FuncGraphPtr, CNodePtr> primal_users_map;
|
||||
auto &node_user_map = manager->node_users();
|
||||
// Search primal graph user cnodes.
|
||||
for (auto &entry : primal_graph->func_graph_cnodes_index()) {
|
||||
|
@ -833,7 +833,13 @@ static std::pair<CNodePtr, CNodePtr> FindPrimalJPair(const FuncGraphManagerPtr &
|
|||
auto index = entry.first->second;
|
||||
if (index == 0) {
|
||||
// To find real calling.
|
||||
primal_user = cnode;
|
||||
auto fg = cnode->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
if (primal_users_map.find(fg) != primal_users_map.end()) {
|
||||
MS_LOG(EXCEPTION) << "The forward network is only allowed to be called once. Func graph: " << fg->ToString()
|
||||
<< ", cnode: " << cnode->DebugString() << ", trace: " << trace::DumpSourceLines(cnode);
|
||||
}
|
||||
primal_users_map[fg] = cnode;
|
||||
} else if (IsPrimitive(cnode->inputs().at(0), prim::kPrimJ)) {
|
||||
// To find J user.
|
||||
auto it = node_user_map.find(cnode);
|
||||
|
@ -845,13 +851,33 @@ static std::pair<CNodePtr, CNodePtr> FindPrimalJPair(const FuncGraphManagerPtr &
|
|||
if (size != 1) {
|
||||
MS_LOG(EXCEPTION) << "Wrong J CNode use size " << size << " {" << cnode->DebugString(2) << "/" << index << "}";
|
||||
}
|
||||
j_user = j_users.begin()->first->cast<CNodePtr>();
|
||||
}
|
||||
if (j_user != nullptr && primal_user != nullptr) {
|
||||
break;
|
||||
auto j_user = j_users.begin()->first->cast<CNodePtr>();
|
||||
primal_j_pair.push_back({nullptr, j_user});
|
||||
}
|
||||
}
|
||||
return {primal_user, j_user};
|
||||
|
||||
for (auto &[primal_user, j_user] : primal_j_pair) {
|
||||
// Check if J operation has relevant primal call in the same graph
|
||||
auto graph = j_user->func_graph();
|
||||
auto iter = primal_users_map.find(graph);
|
||||
if (iter == primal_users_map.end()) {
|
||||
MS_LOG(WARNING) << "J operation has no relevant primal call in the same graph. Func graph: " << graph->ToString()
|
||||
<< ", J user: " << j_user->DebugString();
|
||||
continue;
|
||||
}
|
||||
// Check input size.
|
||||
auto primal = iter->second;
|
||||
if (primal->size() != j_user->size()) {
|
||||
MS_LOG(WARNING) << "Input size incorrect, the input size of primal " << primal->DebugString() << " is "
|
||||
<< primal->size() << ", and J user " << j_user->DebugString() << " is " << j_user->size();
|
||||
continue;
|
||||
}
|
||||
|
||||
primal_user = primal;
|
||||
MS_LOG(DEBUG) << "Primal_J pair is found, where primal is: " << primal->DebugString()
|
||||
<< " and J user is: " << j_user->DebugString();
|
||||
}
|
||||
return primal_j_pair;
|
||||
}
|
||||
|
||||
static void RemovePrimalUpdateStates(const FuncGraphManagerPtr &manager, const CNodePtr &primal_call) {
|
||||
|
@ -915,40 +941,37 @@ void DFunctor::EliminatePrimalGraph() {
|
|||
// Find primal user and paired J user cnodes.
|
||||
auto manager = primal_graph_->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
auto [primal_user, j_user] = FindPrimalJPair(manager, primal_graph_);
|
||||
if (primal_user == nullptr || j_user == nullptr) {
|
||||
// Skip if one of them not found.
|
||||
return;
|
||||
}
|
||||
// Check input size.
|
||||
if (primal_user->size() != j_user->size()) {
|
||||
MS_LOG(WARNING) << "Input size incorrect, primal:" << primal_user->DebugString()
|
||||
<< " juser:" << j_user->DebugString();
|
||||
return;
|
||||
}
|
||||
// Replace primal graph with k graph.
|
||||
auto k_vnode = NewValueNode(k_graph_);
|
||||
auto primal_abs = primal_user->abstract();
|
||||
primal_user->set_input(0, k_vnode);
|
||||
primal_user->set_abstract(j_user->abstract());
|
||||
auto prim_j_pair = FindPrimalJPair(manager, primal_graph_);
|
||||
for (auto &[primal_user, j_user] : prim_j_pair) {
|
||||
if (primal_user == nullptr || j_user == nullptr) {
|
||||
// Skip if one of them not found.
|
||||
return;
|
||||
}
|
||||
|
||||
// If both inputs are same except monads, we copy primal monad args to k graph
|
||||
// so that they can be combined in CSE (common subexpression elimination) pass.
|
||||
const bool has_monad = CopyMonadArguments(primal_user, j_user);
|
||||
// Remove the UpdateState nodes after primal_user if need.
|
||||
if (has_monad) {
|
||||
RemovePrimalUpdateStates(manager, primal_user);
|
||||
}
|
||||
// Replace primal graph with k graph.
|
||||
auto k_vnode = NewValueNode(k_graph_);
|
||||
auto primal_abs = primal_user->abstract();
|
||||
primal_user->set_input(0, k_vnode);
|
||||
primal_user->set_abstract(j_user->abstract());
|
||||
|
||||
// Insert tuple_getitem after primal user cnode.
|
||||
auto construct_wrapper = primal_user->func_graph();
|
||||
auto tuple_getitem = NewValueNode(prim::kPrimTupleGetItem);
|
||||
auto imm0 = std::make_shared<Int64Imm>(0);
|
||||
auto idx0 = NewValueNode(SizeToLong(0));
|
||||
idx0->set_abstract(std::make_shared<abstract::AbstractScalar>(imm0));
|
||||
auto getitem0 = construct_wrapper->NewCNode({tuple_getitem, primal_user, idx0});
|
||||
getitem0->set_abstract(primal_abs);
|
||||
manager->Replace(primal_user, getitem0);
|
||||
// If both inputs are same except monads, we copy primal monad args to k graph
|
||||
// so that they can be combined in CSE (common subexpression elimination) pass.
|
||||
const bool has_monad = CopyMonadArguments(primal_user, j_user);
|
||||
// Remove the UpdateState nodes after primal_user if need.
|
||||
if (has_monad) {
|
||||
RemovePrimalUpdateStates(manager, primal_user);
|
||||
}
|
||||
|
||||
// Insert tuple_getitem after primal user cnode.
|
||||
auto construct_wrapper = primal_user->func_graph();
|
||||
auto tuple_getitem = NewValueNode(prim::kPrimTupleGetItem);
|
||||
auto imm0 = std::make_shared<Int64Imm>(0);
|
||||
auto idx0 = NewValueNode(SizeToLong(0));
|
||||
idx0->set_abstract(std::make_shared<abstract::AbstractScalar>(imm0));
|
||||
auto getitem0 = construct_wrapper->NewCNode({tuple_getitem, primal_user, idx0});
|
||||
getitem0->set_abstract(primal_abs);
|
||||
manager->Replace(primal_user, getitem0);
|
||||
}
|
||||
}
|
||||
} // namespace ad
|
||||
} // namespace mindspore
|
||||
|
|
Loading…
Reference in New Issue