forked from mindspore-Ecosystem/mindspore
reorder virtual node for heter exe
This commit is contained in:
parent
a815ad72c1
commit
0a3cd37002
|
@ -82,28 +82,40 @@ void CalcNodeRefCount(const FuncGraphPtr &graph, std::map<AnfNodePtr, size_t> *n
|
|||
}
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> OptimizeGetItemOrder(const std::vector<AnfNodePtr> &nodes) {
|
||||
std::vector<AnfNodePtr> ReorderVirtualNode(const std::vector<AnfNodePtr> &nodes, const PrimitivePtr &reorder_prim) {
|
||||
std::vector<AnfNodePtr> result;
|
||||
std::map<size_t, std::vector<AnfNodePtr>> insert_positions;
|
||||
std::map<AnfNodePtr, size_t> node_positions;
|
||||
auto add_insert_position = [&insert_positions, &node_positions](const AnfNodePtr &node, const AnfNodePtr &parent) {
|
||||
if (parent == nullptr) {
|
||||
return false;
|
||||
}
|
||||
auto iter = node_positions.find(parent);
|
||||
if (iter != node_positions.end()) {
|
||||
size_t position = iter->second;
|
||||
auto iter_nodes = insert_positions.find(position);
|
||||
if (iter_nodes != insert_positions.end()) {
|
||||
iter_nodes->second.push_back(node);
|
||||
} else {
|
||||
(void)insert_positions.emplace(position, std::vector<AnfNodePtr>{node});
|
||||
}
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
};
|
||||
for (auto &node : nodes) {
|
||||
if (node->isa<CNode>() && IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
|
||||
if (node->isa<CNode>() && IsPrimitiveCNode(node, reorder_prim)) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto &inputs = cnode->inputs();
|
||||
if (inputs.size() <= 1) {
|
||||
MS_LOG(EXCEPTION) << "Invalid get item node";
|
||||
AnfNodePtr parent = nullptr;
|
||||
const size_t depend_input_size = 2;
|
||||
if (reorder_prim == prim::kPrimDepend && inputs.size() == depend_input_size + 1 && !inputs[1]->isa<CNode>()) {
|
||||
parent = inputs[depend_input_size];
|
||||
} else if (reorder_prim == prim::kPrimTupleGetItem && inputs.size() > 1) {
|
||||
parent = inputs[1];
|
||||
}
|
||||
auto &parent = inputs[1];
|
||||
auto iter = node_positions.find(parent);
|
||||
if (iter != node_positions.end()) {
|
||||
size_t position = iter->second;
|
||||
auto iter_nodes = insert_positions.find(position);
|
||||
if (iter_nodes != insert_positions.end()) {
|
||||
iter_nodes->second.push_back(node);
|
||||
} else {
|
||||
(void)insert_positions.emplace(position, std::vector<AnfNodePtr>{node});
|
||||
}
|
||||
if (add_insert_position(node, parent)) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
@ -607,7 +619,8 @@ std::vector<GraphSegmentPtr> GraphPartition::Partition(const FuncGraphPtr &graph
|
|||
} else {
|
||||
nodes = SplitSort(graph, default_target);
|
||||
}
|
||||
nodes = OptimizeGetItemOrder(nodes);
|
||||
nodes = ReorderVirtualNode(nodes, prim::kPrimTupleGetItem);
|
||||
nodes = ReorderVirtualNode(nodes, prim::kPrimDepend);
|
||||
}
|
||||
std::vector<GraphSegmentPtr> segments;
|
||||
std::vector<AnfNodePtr> segment_nodes;
|
||||
|
|
Loading…
Reference in New Issue