Fix segment input/output caculation

This commit is contained in:
He Wei 2021-05-07 19:35:32 +08:00
parent 1dc0efbab5
commit 98bab1eb87
3 changed files with 36 additions and 29 deletions

View File

@ -727,7 +727,8 @@ void SessionBasic::GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph,
MS_EXCEPTION_IF_NULL(other_graph_cnode);
MS_EXCEPTION_IF_NULL(cnode_inputs);
auto origin_inputs = cnode->inputs();
bool optimize_depend = IsPrimitiveCNode(cnode, prim::kPrimDepend) && origin_inputs.size() >= 3;
const bool is_depend = IsPrimitiveCNode(cnode, prim::kPrimDepend);
const bool is_updatestate = IsPrimitiveCNode(cnode, prim::kPrimUpdateState);
// if has multiple depends,only select first depend as parameter
for (size_t input_idx = 1; input_idx < origin_inputs.size(); input_idx++) {
auto anf = origin_inputs[input_idx];
@ -736,7 +737,7 @@ void SessionBasic::GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph,
if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) {
(void)cnode_inputs->emplace_back(graph->GetBackendAnfByFrontAnf(anf));
continue;
} else if (optimize_depend && input_idx > 1) {
} else if ((is_depend && input_idx > 1) || (is_updatestate && input_idx > 2)) {
cnode_inputs->push_back(NewValueNode(MakeValue(SizeToInt(input_idx))));
continue;
} else if (other_graph_cnode->find(anf) != other_graph_cnode->end()) {

View File

@ -22,6 +22,7 @@
#include <functional>
#include <memory>
#include <set>
#include <unordered_set>
#include <tuple>
#include <unordered_map>
#include <utility>
@ -39,43 +40,39 @@ namespace compile {
ConvertCache g_ConvertCache;
void ClearConvertCache() { g_ConvertCache.clear(); }
namespace {
// Return the list of nodes whose values are required beyond this segment.
// Arguments:
// lst: list of nodes (the segment)
// nodes: list of nodes in the segment
// users: dict mapping each node to its users (globally)
// seen: set of nodes that are part of the segment
AnfNodePtrList GetOutput(const AnfNodePtrList &lst, const NodeUsersMap &users, const std::vector<AnfNodePtr> &seen) {
AnfNodePtrList GetOutput(const AnfNodePtrList &nodes, const NodeUsersMap &users,
const std::unordered_set<AnfNodePtr> &seen) {
AnfNodePtrList output;
if (users.size() == 0) {
return output;
}
(void)std::transform(
std::begin(lst), std::end(lst), std::back_inserter(output), [&users, &seen](AnfNodePtr n) -> AnfNodePtr {
auto usersn = users.find(n);
bool is_referred_out_of_segment = std::any_of(
std::begin(usersn->second), std::end(usersn->second), [&seen](const std::pair<AnfNodePtr, int64_t> &u) -> bool {
return std::find(std::begin(seen), std::end(seen), u.first) == std::end(seen);
});
if (n->isa<CNode>() && is_referred_out_of_segment) {
return n;
}
return nullptr;
});
// remove nullptr
for (auto it = output.begin(); it != output.end();) {
if (*it == nullptr) {
it = output.erase(it);
} else {
++it;
for (auto &node : nodes) {
if (!node->isa<CNode>()) {
continue;
}
auto iter = users.find(node);
if (iter == users.end()) {
continue;
}
auto &node_users = iter->second;
const bool has_outer_user = std::any_of(
std::begin(node_users), std::end(node_users), [&seen](const std::pair<AnfNodePtr, int64_t> &u) -> bool {
const bool is_outer_user = (seen.find(u.first) == seen.end());
return is_outer_user && !(IsPrimitiveCNode(u.first, prim::kPrimUpdateState) && u.second > 2);
});
if (has_outer_user) {
output.emplace_back(node);
}
}
return output;
}
namespace {
AnfNodePtr RefSubGraphNode(const FuncGraphPtr &fg, const AnfNodePtr &node, AnfNodePtrList *const inputs_ptr,
AnfNodePtrToAnfNodePtrMap *eqv_ptr) {
MS_EXCEPTION_IF_NULL(fg);
@ -129,6 +126,15 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGr
for (size_t i = 2; i < inps.size(); ++i) {
args.emplace_back(NewValueNode(MakeValue(0)));
}
} else if (IsPrimitive(fn, prim::kPrimUpdateState)) {
args.emplace_back(RefSubGraphNode(fg, inps[1], &inputs, &eqv));
args.emplace_back(RefSubGraphNode(fg, inps[2], &inputs, &eqv));
for (size_t i = 3; i < inps.size(); ++i) {
auto &input = inps[i];
if (eqv.find(input) != eqv.end()) {
args.emplace_back(RefSubGraphNode(fg, input, &inputs, &eqv));
}
}
} else {
(void)std::transform(std::begin(inps) + 1, std::end(inps), std::back_inserter(args),
[&fg, &inputs, &eqv](const AnfNodePtr &a) { return RefSubGraphNode(fg, a, &inputs, &eqv); });
@ -138,8 +144,8 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGr
eqv[n]->set_abstract(n->abstract());
eqv[n]->set_kernel_info(n->kernel_info_ptr());
}
std::vector<AnfNodePtr> eqv_keys;
(void)std::transform(std::begin(eqv), std::end(eqv), std::back_inserter(eqv_keys),
std::unordered_set<AnfNodePtr> eqv_keys;
(void)std::transform(std::begin(eqv), std::end(eqv), std::inserter(eqv_keys, eqv_keys.end()),
[](const std::pair<AnfNodePtr, AnfNodePtr> &elem) -> AnfNodePtr { return elem.first; });
auto outputs = GetOutput(lst, lst[0]->func_graph()->manager()->node_users(), eqv_keys);
AnfNodePtr fg_output;

View File

@ -465,7 +465,7 @@ std::string GetCNodeTarget(const AnfNodePtr &node) {
}
} else if (IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
auto &inputs = cnode->inputs();
if (inputs.size() == 3 && !IsPrimitiveCNode(inputs[2], prim::kPrimMakeTuple)) {
if (inputs.size() >= 3 && !IsPrimitiveCNode(inputs[2], prim::kPrimMakeTuple)) {
return GetCNodeTarget(inputs[2]);
}
} else if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {