forked from mindspore-Ecosystem/mindspore
Fix segment input/output caculation
This commit is contained in:
parent
1dc0efbab5
commit
98bab1eb87
|
@ -727,7 +727,8 @@ void SessionBasic::GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph,
|
|||
MS_EXCEPTION_IF_NULL(other_graph_cnode);
|
||||
MS_EXCEPTION_IF_NULL(cnode_inputs);
|
||||
auto origin_inputs = cnode->inputs();
|
||||
bool optimize_depend = IsPrimitiveCNode(cnode, prim::kPrimDepend) && origin_inputs.size() >= 3;
|
||||
const bool is_depend = IsPrimitiveCNode(cnode, prim::kPrimDepend);
|
||||
const bool is_updatestate = IsPrimitiveCNode(cnode, prim::kPrimUpdateState);
|
||||
// if has multiple depends,only select first depend as parameter
|
||||
for (size_t input_idx = 1; input_idx < origin_inputs.size(); input_idx++) {
|
||||
auto anf = origin_inputs[input_idx];
|
||||
|
@ -736,7 +737,7 @@ void SessionBasic::GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph,
|
|||
if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) {
|
||||
(void)cnode_inputs->emplace_back(graph->GetBackendAnfByFrontAnf(anf));
|
||||
continue;
|
||||
} else if (optimize_depend && input_idx > 1) {
|
||||
} else if ((is_depend && input_idx > 1) || (is_updatestate && input_idx > 2)) {
|
||||
cnode_inputs->push_back(NewValueNode(MakeValue(SizeToInt(input_idx))));
|
||||
continue;
|
||||
} else if (other_graph_cnode->find(anf) != other_graph_cnode->end()) {
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#include <functional>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <unordered_set>
|
||||
#include <tuple>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
|
@ -39,43 +40,39 @@ namespace compile {
|
|||
ConvertCache g_ConvertCache;
|
||||
void ClearConvertCache() { g_ConvertCache.clear(); }
|
||||
|
||||
namespace {
|
||||
// Return the list of nodes whose values are required beyond this segment.
|
||||
// Arguments:
|
||||
// lst: list of nodes (the segment)
|
||||
// nodes: list of nodes in the segment
|
||||
// users: dict mapping each node to its users (globally)
|
||||
// seen: set of nodes that are part of the segment
|
||||
AnfNodePtrList GetOutput(const AnfNodePtrList &lst, const NodeUsersMap &users, const std::vector<AnfNodePtr> &seen) {
|
||||
AnfNodePtrList GetOutput(const AnfNodePtrList &nodes, const NodeUsersMap &users,
|
||||
const std::unordered_set<AnfNodePtr> &seen) {
|
||||
AnfNodePtrList output;
|
||||
if (users.size() == 0) {
|
||||
return output;
|
||||
}
|
||||
|
||||
(void)std::transform(
|
||||
std::begin(lst), std::end(lst), std::back_inserter(output), [&users, &seen](AnfNodePtr n) -> AnfNodePtr {
|
||||
auto usersn = users.find(n);
|
||||
bool is_referred_out_of_segment = std::any_of(
|
||||
std::begin(usersn->second), std::end(usersn->second), [&seen](const std::pair<AnfNodePtr, int64_t> &u) -> bool {
|
||||
return std::find(std::begin(seen), std::end(seen), u.first) == std::end(seen);
|
||||
});
|
||||
if (n->isa<CNode>() && is_referred_out_of_segment) {
|
||||
return n;
|
||||
}
|
||||
return nullptr;
|
||||
});
|
||||
|
||||
// remove nullptr
|
||||
for (auto it = output.begin(); it != output.end();) {
|
||||
if (*it == nullptr) {
|
||||
it = output.erase(it);
|
||||
} else {
|
||||
++it;
|
||||
for (auto &node : nodes) {
|
||||
if (!node->isa<CNode>()) {
|
||||
continue;
|
||||
}
|
||||
auto iter = users.find(node);
|
||||
if (iter == users.end()) {
|
||||
continue;
|
||||
}
|
||||
auto &node_users = iter->second;
|
||||
const bool has_outer_user = std::any_of(
|
||||
std::begin(node_users), std::end(node_users), [&seen](const std::pair<AnfNodePtr, int64_t> &u) -> bool {
|
||||
const bool is_outer_user = (seen.find(u.first) == seen.end());
|
||||
return is_outer_user && !(IsPrimitiveCNode(u.first, prim::kPrimUpdateState) && u.second > 2);
|
||||
});
|
||||
if (has_outer_user) {
|
||||
output.emplace_back(node);
|
||||
}
|
||||
}
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
namespace {
|
||||
AnfNodePtr RefSubGraphNode(const FuncGraphPtr &fg, const AnfNodePtr &node, AnfNodePtrList *const inputs_ptr,
|
||||
AnfNodePtrToAnfNodePtrMap *eqv_ptr) {
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
|
@ -129,6 +126,15 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGr
|
|||
for (size_t i = 2; i < inps.size(); ++i) {
|
||||
args.emplace_back(NewValueNode(MakeValue(0)));
|
||||
}
|
||||
} else if (IsPrimitive(fn, prim::kPrimUpdateState)) {
|
||||
args.emplace_back(RefSubGraphNode(fg, inps[1], &inputs, &eqv));
|
||||
args.emplace_back(RefSubGraphNode(fg, inps[2], &inputs, &eqv));
|
||||
for (size_t i = 3; i < inps.size(); ++i) {
|
||||
auto &input = inps[i];
|
||||
if (eqv.find(input) != eqv.end()) {
|
||||
args.emplace_back(RefSubGraphNode(fg, input, &inputs, &eqv));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
(void)std::transform(std::begin(inps) + 1, std::end(inps), std::back_inserter(args),
|
||||
[&fg, &inputs, &eqv](const AnfNodePtr &a) { return RefSubGraphNode(fg, a, &inputs, &eqv); });
|
||||
|
@ -138,8 +144,8 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGr
|
|||
eqv[n]->set_abstract(n->abstract());
|
||||
eqv[n]->set_kernel_info(n->kernel_info_ptr());
|
||||
}
|
||||
std::vector<AnfNodePtr> eqv_keys;
|
||||
(void)std::transform(std::begin(eqv), std::end(eqv), std::back_inserter(eqv_keys),
|
||||
std::unordered_set<AnfNodePtr> eqv_keys;
|
||||
(void)std::transform(std::begin(eqv), std::end(eqv), std::inserter(eqv_keys, eqv_keys.end()),
|
||||
[](const std::pair<AnfNodePtr, AnfNodePtr> &elem) -> AnfNodePtr { return elem.first; });
|
||||
auto outputs = GetOutput(lst, lst[0]->func_graph()->manager()->node_users(), eqv_keys);
|
||||
AnfNodePtr fg_output;
|
||||
|
|
|
@ -465,7 +465,7 @@ std::string GetCNodeTarget(const AnfNodePtr &node) {
|
|||
}
|
||||
} else if (IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
|
||||
auto &inputs = cnode->inputs();
|
||||
if (inputs.size() == 3 && !IsPrimitiveCNode(inputs[2], prim::kPrimMakeTuple)) {
|
||||
if (inputs.size() >= 3 && !IsPrimitiveCNode(inputs[2], prim::kPrimMakeTuple)) {
|
||||
return GetCNodeTarget(inputs[2]);
|
||||
}
|
||||
} else if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
|
||||
|
|
Loading…
Reference in New Issue