!20721 optimize partial order

Merge pull request !20721 from kisnwang/optimize-partial-order
This commit is contained in:
i-robot 2021-07-28 01:19:54 +00:00 committed by Gitee
commit a6c3dadb96
2 changed files with 40 additions and 11 deletions

View File

@ -136,6 +136,39 @@ std::vector<AnfNodePtr> ReorderVirtualNode(const std::vector<AnfNodePtr> &nodes,
return result; return result;
} }
std::vector<AnfNodePtr> GetNextNodes(const AnfNodePtr &node, std::map<AnfNodePtr, size_t> *nodes_ref,
std::vector<AnfNodePtr> *result) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(nodes_ref);
MS_EXCEPTION_IF_NULL(result);
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto node_inputs = cnode->inputs();
if (!IsPrimitiveCNode(node, prim::kPrimSwitch)) {
std::reverse(node_inputs.begin(), node_inputs.end());
return node_inputs;
}
std::vector<AnfNodePtr> extend_inputs;
for (auto &input : node_inputs) {
MS_EXCEPTION_IF_NULL(input);
if (IsPrimitiveCNode(input, prim::kPrimPartial)) {
auto iter = nodes_ref->find(input);
if (iter != nodes_ref->end() && iter->second == 1) {
iter->second--;
result->emplace_back(input);
auto partial_cnode = input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(partial_cnode);
auto partial_inputs = partial_cnode->inputs();
std::reverse(partial_inputs.begin(), partial_inputs.end());
(void)extend_inputs.insert(extend_inputs.end(), partial_inputs.begin(), partial_inputs.end());
continue;
}
}
extend_inputs.emplace_back(input);
}
return extend_inputs;
}
std::vector<AnfNodePtr> SplitSort(const FuncGraphPtr &graph, const std::string &default_target) { std::vector<AnfNodePtr> SplitSort(const FuncGraphPtr &graph, const std::string &default_target) {
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
std::vector<AnfNodePtr> result; std::vector<AnfNodePtr> result;
@ -158,13 +191,8 @@ std::vector<AnfNodePtr> SplitSort(const FuncGraphPtr &graph, const std::string &
if (!node->isa<CNode>()) { if (!node->isa<CNode>()) {
continue; continue;
} }
auto cnode = node->cast<CNodePtr>(); auto next_nodes = GetNextNodes(node, &nodes_ref, &result);
MS_EXCEPTION_IF_NULL(cnode); for (auto &input : next_nodes) {
auto node_inputs = cnode->inputs();
if (!IsPrimitiveCNode(cnode, prim::kPrimSwitch)) {
std::reverse(node_inputs.begin(), node_inputs.end());
}
for (auto &input : node_inputs) {
MS_EXCEPTION_IF_NULL(input); MS_EXCEPTION_IF_NULL(input);
auto iter = nodes_ref.find(input); auto iter = nodes_ref.find(input);
if (iter != nodes_ref.end()) { if (iter != nodes_ref.end()) {
@ -621,8 +649,9 @@ std::vector<GraphSegmentPtr> GraphPartition::Partition(const FuncGraphPtr &graph
auto context_ptr = MsContext::GetInstance(); auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
auto enable_loop_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_LOOP_SINK);
std::string default_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET); std::string default_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
if (contain_multi_target) { if (contain_multi_target || !enable_loop_sink) {
if (context_ptr->get_param<bool>(MS_CTX_ENABLE_PARALLEL_SPLIT)) { if (context_ptr->get_param<bool>(MS_CTX_ENABLE_PARALLEL_SPLIT)) {
auto other_target = GetOtherTarget(nodes); auto other_target = GetOtherTarget(nodes);
nodes = ParallelSort(graph, default_target, other_target); nodes = ParallelSort(graph, default_target, other_target);

View File

@ -225,7 +225,7 @@ bool CompileGraph::Compile(const FuncGraphPtr &graph) {
int64_t ret = RET_SUCCESS; int64_t ret = RET_SUCCESS;
if (!segment->is_cut_) { if (!segment->is_cut_) {
MS_LOG(DEBUG) << "Start a extern LinConvert"; MS_LOG(DEBUG) << "Start a extern LinConvert";
if (segment->nodes_.size() > 0) { if (!segment->nodes_.empty()) {
std::string cur_target = GetCNodeTarget(segment->nodes_[0]); std::string cur_target = GetCNodeTarget(segment->nodes_[0]);
ret = LinConvert(graph, segment, cur_target); ret = LinConvert(graph, segment, cur_target);
} else { } else {
@ -238,14 +238,14 @@ bool CompileGraph::Compile(const FuncGraphPtr &graph) {
if (ret == RET_CONTINUE) { if (ret == RET_CONTINUE) {
continue; continue;
} }
} else { } else if (!segment->nodes_.empty()) {
MS_LOG(DEBUG) << "Start a cut node"; MS_LOG(DEBUG) << "Start a cut node";
auto &cut_node = segment->nodes_[0]; auto &cut_node = segment->nodes_[0];
MS_EXCEPTION_IF_NULL(cut_node); MS_EXCEPTION_IF_NULL(cut_node);
if (!cut_node->isa<CNode>()) { if (!cut_node->isa<CNode>()) {
MS_LOG(EXCEPTION) << "must be anfnode here NodeInfo: " << trace::GetDebugInfo(graph->debug_info()); MS_LOG(EXCEPTION) << "must be anfnode here NodeInfo: " << trace::GetDebugInfo(graph->debug_info());
} }
CNodePtr node = cut_node->cast<CNodePtr>(); auto node = cut_node->cast<CNodePtr>();
ret = InterpretNode(graph, node); ret = InterpretNode(graph, node);
MS_LOG(DEBUG) << "End a cut node"; MS_LOG(DEBUG) << "End a cut node";
if (ret == RET_BREAK) { if (ret == RET_BREAK) {