forked from mindspore-Ecosystem/mindspore
!20721 optimize partial order
Merge pull request !20721 from kisnwang/optimize-partial-order
This commit is contained in:
commit
a6c3dadb96
|
@ -136,6 +136,39 @@ std::vector<AnfNodePtr> ReorderVirtualNode(const std::vector<AnfNodePtr> &nodes,
|
|||
return result;
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> GetNextNodes(const AnfNodePtr &node, std::map<AnfNodePtr, size_t> *nodes_ref,
|
||||
std::vector<AnfNodePtr> *result) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(nodes_ref);
|
||||
MS_EXCEPTION_IF_NULL(result);
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto node_inputs = cnode->inputs();
|
||||
if (!IsPrimitiveCNode(node, prim::kPrimSwitch)) {
|
||||
std::reverse(node_inputs.begin(), node_inputs.end());
|
||||
return node_inputs;
|
||||
}
|
||||
std::vector<AnfNodePtr> extend_inputs;
|
||||
for (auto &input : node_inputs) {
|
||||
MS_EXCEPTION_IF_NULL(input);
|
||||
if (IsPrimitiveCNode(input, prim::kPrimPartial)) {
|
||||
auto iter = nodes_ref->find(input);
|
||||
if (iter != nodes_ref->end() && iter->second == 1) {
|
||||
iter->second--;
|
||||
result->emplace_back(input);
|
||||
auto partial_cnode = input->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(partial_cnode);
|
||||
auto partial_inputs = partial_cnode->inputs();
|
||||
std::reverse(partial_inputs.begin(), partial_inputs.end());
|
||||
(void)extend_inputs.insert(extend_inputs.end(), partial_inputs.begin(), partial_inputs.end());
|
||||
continue;
|
||||
}
|
||||
}
|
||||
extend_inputs.emplace_back(input);
|
||||
}
|
||||
return extend_inputs;
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> SplitSort(const FuncGraphPtr &graph, const std::string &default_target) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
std::vector<AnfNodePtr> result;
|
||||
|
@ -158,13 +191,8 @@ std::vector<AnfNodePtr> SplitSort(const FuncGraphPtr &graph, const std::string &
|
|||
if (!node->isa<CNode>()) {
|
||||
continue;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto node_inputs = cnode->inputs();
|
||||
if (!IsPrimitiveCNode(cnode, prim::kPrimSwitch)) {
|
||||
std::reverse(node_inputs.begin(), node_inputs.end());
|
||||
}
|
||||
for (auto &input : node_inputs) {
|
||||
auto next_nodes = GetNextNodes(node, &nodes_ref, &result);
|
||||
for (auto &input : next_nodes) {
|
||||
MS_EXCEPTION_IF_NULL(input);
|
||||
auto iter = nodes_ref.find(input);
|
||||
if (iter != nodes_ref.end()) {
|
||||
|
@ -621,8 +649,9 @@ std::vector<GraphSegmentPtr> GraphPartition::Partition(const FuncGraphPtr &graph
|
|||
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
auto enable_loop_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_LOOP_SINK);
|
||||
std::string default_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
|
||||
if (contain_multi_target) {
|
||||
if (contain_multi_target || !enable_loop_sink) {
|
||||
if (context_ptr->get_param<bool>(MS_CTX_ENABLE_PARALLEL_SPLIT)) {
|
||||
auto other_target = GetOtherTarget(nodes);
|
||||
nodes = ParallelSort(graph, default_target, other_target);
|
||||
|
|
|
@ -225,7 +225,7 @@ bool CompileGraph::Compile(const FuncGraphPtr &graph) {
|
|||
int64_t ret = RET_SUCCESS;
|
||||
if (!segment->is_cut_) {
|
||||
MS_LOG(DEBUG) << "Start a extern LinConvert";
|
||||
if (segment->nodes_.size() > 0) {
|
||||
if (!segment->nodes_.empty()) {
|
||||
std::string cur_target = GetCNodeTarget(segment->nodes_[0]);
|
||||
ret = LinConvert(graph, segment, cur_target);
|
||||
} else {
|
||||
|
@ -238,14 +238,14 @@ bool CompileGraph::Compile(const FuncGraphPtr &graph) {
|
|||
if (ret == RET_CONTINUE) {
|
||||
continue;
|
||||
}
|
||||
} else {
|
||||
} else if (!segment->nodes_.empty()) {
|
||||
MS_LOG(DEBUG) << "Start a cut node";
|
||||
auto &cut_node = segment->nodes_[0];
|
||||
MS_EXCEPTION_IF_NULL(cut_node);
|
||||
if (!cut_node->isa<CNode>()) {
|
||||
MS_LOG(EXCEPTION) << "must be anfnode here NodeInfo: " << trace::GetDebugInfo(graph->debug_info());
|
||||
}
|
||||
CNodePtr node = cut_node->cast<CNodePtr>();
|
||||
auto node = cut_node->cast<CNodePtr>();
|
||||
ret = InterpretNode(graph, node);
|
||||
MS_LOG(DEBUG) << "End a cut node";
|
||||
if (ret == RET_BREAK) {
|
||||
|
|
Loading…
Reference in New Issue