reorder virtual node for heter exe

This commit is contained in:
kswang 2021-07-05 22:16:46 +08:00
parent a815ad72c1
commit 0a3cd37002
1 changed files with 28 additions and 15 deletions

View File

@ -82,28 +82,40 @@ void CalcNodeRefCount(const FuncGraphPtr &graph, std::map<AnfNodePtr, size_t> *n
}
}
std::vector<AnfNodePtr> OptimizeGetItemOrder(const std::vector<AnfNodePtr> &nodes) {
std::vector<AnfNodePtr> ReorderVirtualNode(const std::vector<AnfNodePtr> &nodes, const PrimitivePtr &reorder_prim) {
std::vector<AnfNodePtr> result;
std::map<size_t, std::vector<AnfNodePtr>> insert_positions;
std::map<AnfNodePtr, size_t> node_positions;
auto add_insert_position = [&insert_positions, &node_positions](const AnfNodePtr &node, const AnfNodePtr &parent) {
if (parent == nullptr) {
return false;
}
auto iter = node_positions.find(parent);
if (iter != node_positions.end()) {
size_t position = iter->second;
auto iter_nodes = insert_positions.find(position);
if (iter_nodes != insert_positions.end()) {
iter_nodes->second.push_back(node);
} else {
(void)insert_positions.emplace(position, std::vector<AnfNodePtr>{node});
}
return true;
}
return false;
};
for (auto &node : nodes) {
if (node->isa<CNode>() && IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
if (node->isa<CNode>() && IsPrimitiveCNode(node, reorder_prim)) {
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto &inputs = cnode->inputs();
if (inputs.size() <= 1) {
MS_LOG(EXCEPTION) << "Invalid get item node";
AnfNodePtr parent = nullptr;
const size_t depend_input_size = 2;
if (reorder_prim == prim::kPrimDepend && inputs.size() == depend_input_size + 1 && !inputs[1]->isa<CNode>()) {
parent = inputs[depend_input_size];
} else if (reorder_prim == prim::kPrimTupleGetItem && inputs.size() > 1) {
parent = inputs[1];
}
auto &parent = inputs[1];
auto iter = node_positions.find(parent);
if (iter != node_positions.end()) {
size_t position = iter->second;
auto iter_nodes = insert_positions.find(position);
if (iter_nodes != insert_positions.end()) {
iter_nodes->second.push_back(node);
} else {
(void)insert_positions.emplace(position, std::vector<AnfNodePtr>{node});
}
if (add_insert_position(node, parent)) {
continue;
}
}
@ -607,7 +619,8 @@ std::vector<GraphSegmentPtr> GraphPartition::Partition(const FuncGraphPtr &graph
} else {
nodes = SplitSort(graph, default_target);
}
nodes = OptimizeGetItemOrder(nodes);
nodes = ReorderVirtualNode(nodes, prim::kPrimTupleGetItem);
nodes = ReorderVirtualNode(nodes, prim::kPrimDepend);
}
std::vector<GraphSegmentPtr> segments;
std::vector<AnfNodePtr> segment_nodes;