!20721 optimize partial order

Merge pull request !20721 from kisnwang/optimize-partial-order
This commit is contained in:
i-robot 2021-07-28 01:19:54 +00:00 committed by Gitee
commit a6c3dadb96
2 changed files with 40 additions and 11 deletions

View File

@ -136,6 +136,39 @@ std::vector<AnfNodePtr> ReorderVirtualNode(const std::vector<AnfNodePtr> &nodes,
return result;
}
std::vector<AnfNodePtr> GetNextNodes(const AnfNodePtr &node, std::map<AnfNodePtr, size_t> *nodes_ref,
std::vector<AnfNodePtr> *result) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(nodes_ref);
MS_EXCEPTION_IF_NULL(result);
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto node_inputs = cnode->inputs();
if (!IsPrimitiveCNode(node, prim::kPrimSwitch)) {
std::reverse(node_inputs.begin(), node_inputs.end());
return node_inputs;
}
std::vector<AnfNodePtr> extend_inputs;
for (auto &input : node_inputs) {
MS_EXCEPTION_IF_NULL(input);
if (IsPrimitiveCNode(input, prim::kPrimPartial)) {
auto iter = nodes_ref->find(input);
if (iter != nodes_ref->end() && iter->second == 1) {
iter->second--;
result->emplace_back(input);
auto partial_cnode = input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(partial_cnode);
auto partial_inputs = partial_cnode->inputs();
std::reverse(partial_inputs.begin(), partial_inputs.end());
(void)extend_inputs.insert(extend_inputs.end(), partial_inputs.begin(), partial_inputs.end());
continue;
}
}
extend_inputs.emplace_back(input);
}
return extend_inputs;
}
std::vector<AnfNodePtr> SplitSort(const FuncGraphPtr &graph, const std::string &default_target) {
MS_EXCEPTION_IF_NULL(graph);
std::vector<AnfNodePtr> result;
@ -158,13 +191,8 @@ std::vector<AnfNodePtr> SplitSort(const FuncGraphPtr &graph, const std::string &
if (!node->isa<CNode>()) {
continue;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto node_inputs = cnode->inputs();
if (!IsPrimitiveCNode(cnode, prim::kPrimSwitch)) {
std::reverse(node_inputs.begin(), node_inputs.end());
}
for (auto &input : node_inputs) {
auto next_nodes = GetNextNodes(node, &nodes_ref, &result);
for (auto &input : next_nodes) {
MS_EXCEPTION_IF_NULL(input);
auto iter = nodes_ref.find(input);
if (iter != nodes_ref.end()) {
@ -621,8 +649,9 @@ std::vector<GraphSegmentPtr> GraphPartition::Partition(const FuncGraphPtr &graph
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
auto enable_loop_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_LOOP_SINK);
std::string default_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
if (contain_multi_target) {
if (contain_multi_target || !enable_loop_sink) {
if (context_ptr->get_param<bool>(MS_CTX_ENABLE_PARALLEL_SPLIT)) {
auto other_target = GetOtherTarget(nodes);
nodes = ParallelSort(graph, default_target, other_target);

View File

@ -225,7 +225,7 @@ bool CompileGraph::Compile(const FuncGraphPtr &graph) {
int64_t ret = RET_SUCCESS;
if (!segment->is_cut_) {
MS_LOG(DEBUG) << "Start a extern LinConvert";
if (segment->nodes_.size() > 0) {
if (!segment->nodes_.empty()) {
std::string cur_target = GetCNodeTarget(segment->nodes_[0]);
ret = LinConvert(graph, segment, cur_target);
} else {
@ -238,14 +238,14 @@ bool CompileGraph::Compile(const FuncGraphPtr &graph) {
if (ret == RET_CONTINUE) {
continue;
}
} else {
} else if (!segment->nodes_.empty()) {
MS_LOG(DEBUG) << "Start a cut node";
auto &cut_node = segment->nodes_[0];
MS_EXCEPTION_IF_NULL(cut_node);
if (!cut_node->isa<CNode>()) {
MS_LOG(EXCEPTION) << "must be anfnode here NodeInfo: " << trace::GetDebugInfo(graph->debug_info());
}
CNodePtr node = cut_node->cast<CNodePtr>();
auto node = cut_node->cast<CNodePtr>();
ret = InterpretNode(graph, node);
MS_LOG(DEBUG) << "End a cut node";
if (ret == RET_BREAK) {