!10907 execution reorder to overide trailing time of the last allreduce with optimizer

From: @shibeiji
Reviewed-by: @kisnwang
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-01-06 10:11:14 +08:00 committed by Gitee
commit b5313fcc05
2 changed files with 308 additions and 0 deletions

View File

@ -25,6 +25,7 @@
#include "backend/session/anf_runtime_algorithm.h"
#include "runtime/device/kernel_adjust.h"
#include "backend/optimizer/common/helper.h"
#include "backend/kernel_compiler/oplib/oplib.h"
#include "utils/utils.h"
namespace mindspore {
@ -38,6 +39,7 @@ void AscendStreamAssign::AssignStream(const NotNull<KernelGraphPtr> &graph_ptr)
Reset();
SetLoopSink();
ReorderIndependentOrders(graph_ptr);
TrailingTimeOptimizationByReorder(graph_ptr);
AssignAllNodesStream(graph_ptr);
UpdateAtomicAddrCleanStreamId(graph_ptr);
@ -128,6 +130,305 @@ void AscendStreamAssign::ReorderIndependentOrders(const NotNull<KernelGraphPtr>
graph_ptr->set_execution_order(exe_orders);
}
void AscendStreamAssign::CheckScenario(const NotNull<KernelGraphPtr> &graph_ptr,
vector<CNodePtr> *last_grad_and_status) {
auto cnode_ptr_list = graph_ptr->execution_order();
vector<CNodePtr> hcom_nodes;
CNodePtr cur_cnode_ptr = nullptr;
CNodePtr overflow_marker = nullptr;
std::string kNPUGetFloatStatusOpName = "NPUGetFloatStatus";
for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
cur_cnode_ptr = cnode_ptr_list[i];
MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kNPUGetFloatStatusOpName) {
overflow_marker = cur_cnode_ptr;
} else if (IsHcom(cur_cnode_ptr)) {
hcom_nodes.emplace_back(cur_cnode_ptr);
} else if (i > 0 && AnfAlgo::GetCNodeName(cnode_ptr_list[i - 1]) == kAtomicAddrCleanOpName) {
auto graph_id = AnfAlgo::GetGraphId(cur_cnode_ptr.get());
AnfAlgo::SetGraphId(graph_id, cnode_ptr_list[i - 1].get());
}
}
if (hcom_nodes.size() < 2 || overflow_marker == nullptr) {
MS_LOG(INFO) << "Current model isn't in distribute or mix-precision mode, no optimization needed";
last_grad_and_status->clear();
return;
}
auto overflow_marker_pos = find(cnode_ptr_list.begin(), cnode_ptr_list.end(), overflow_marker);
auto last_hcom_ptr = hcom_nodes[hcom_nodes.size() - 1];
auto last_hcom_pos = find(cnode_ptr_list.begin(), cnode_ptr_list.end(), last_hcom_ptr);
auto last_grad_hcom_ptr = hcom_nodes[hcom_nodes.size() - 2];
auto last_grad_hcom_pos = find(cnode_ptr_list.begin(), cnode_ptr_list.end(), last_grad_hcom_ptr);
if (last_grad_hcom_pos > overflow_marker_pos || last_hcom_pos < overflow_marker_pos) {
MS_LOG(INFO) << "Grads average done after overflow judgement or status aren't allgathered, no optimization needed";
last_grad_and_status->clear();
return;
}
auto last_inputs = GetLastInputCnode(graph_ptr, last_grad_hcom_ptr);
if (last_inputs.empty() || last_inputs.size() > 1 || IsHcom(last_inputs[0])) {
MS_LOG(INFO) << "Inputs of last gradients allreduce is empty or include other allreduce, no optimization needed";
last_grad_and_status->clear();
return;
}
auto last_grad_ptr = last_inputs[0];
MS_LOG(DEBUG) << "Last Hcom: " << last_grad_hcom_ptr->fullname_with_scope()
<< "; last input: " << last_grad_ptr->fullname_with_scope();
auto last_grad_hcom_graph_id = AnfAlgo::GetGraphId(last_grad_hcom_ptr.get());
auto last_grad_graph_id = AnfAlgo::GetGraphId(last_grad_ptr.get());
auto overflow_marker_graph_id = AnfAlgo::GetGraphId(overflow_marker.get());
if (last_grad_graph_id != last_grad_hcom_graph_id || last_grad_graph_id != overflow_marker_graph_id) {
MS_LOG(INFO) << "The grads and grad_hcom or overflow marker were not on the same subgraph, no optimization needed";
last_grad_and_status->clear();
return;
}
auto label_switch_pos = find_if(last_grad_hcom_pos, cnode_ptr_list.end(),
[](CNodePtr &node) -> bool { return AnfAlgo::GetCNodeName(node) == "LabelSwitch"; });
if (label_switch_pos == cnode_ptr_list.end()) {
MS_LOG(INFO) << "No branches after getting overflow status, no optimization needed";
last_grad_and_status->clear();
return;
}
last_grad_and_status->emplace_back(last_grad_ptr);
last_grad_and_status->emplace_back(overflow_marker);
return;
}
CNodePtr AscendStreamAssign::GetCNodesNeededMoved(vector<CNodePtr> *moved_backward_cnodes,
vector<CNodePtr> *moved_forward_cnodes,
const vector<CNodePtr> &last_grad_and_status,
const NotNull<KernelGraphPtr> &graph_ptr) {
auto cnode_ptr_list = graph_ptr->execution_order();
if (last_grad_and_status.size() != 2) {
return nullptr;
}
auto last_grad_ptr = last_grad_and_status[0];
auto float_status_ptr = last_grad_and_status[1];
auto last_grad_pos = find(cnode_ptr_list.begin(), cnode_ptr_list.end(), last_grad_ptr);
auto float_status_pos = find(cnode_ptr_list.begin(), cnode_ptr_list.end(), float_status_ptr);
if (last_grad_pos == cnode_ptr_list.end() || float_status_pos == cnode_ptr_list.end()) {
moved_backward_cnodes->clear();
moved_forward_cnodes->clear();
return nullptr;
}
auto graph_id = AnfAlgo::GetGraphId(last_grad_ptr.get());
moved_backward_cnodes->insert(moved_backward_cnodes->end(), last_grad_pos + 1, float_status_pos);
auto it = float_status_pos;
while (AnfAlgo::GetGraphId((*it).get()) == graph_id && it < cnode_ptr_list.end()) {
if (AnfAlgo::GetCNodeName(*it) == kAtomicAddrCleanOpName) {
it++;
continue;
}
auto inputs = GetInputKernels(*it);
bool is_independent = true;
for (auto &input : inputs) {
if (find(moved_backward_cnodes->begin(), moved_backward_cnodes->end(), input) != moved_backward_cnodes->end()) {
is_independent = false;
break;
}
}
if (is_independent) {
if (AnfAlgo::GetCNodeName(*(it - 1)) == kAtomicAddrCleanOpName) {
moved_forward_cnodes->emplace_back(*(it - 1));
}
moved_forward_cnodes->emplace_back(*it);
} else {
if (AnfAlgo::GetCNodeName(*(it - 1)) == kAtomicAddrCleanOpName) {
moved_backward_cnodes->emplace_back(*(it - 1));
}
moved_backward_cnodes->emplace_back(*it);
}
it++;
}
// check ref nodes
for (auto &cnode : *moved_backward_cnodes) {
std::string op_name = AnfAlgo::GetCNodeName(cnode);
auto op_info = mindspore::kernel::OpLib::FindOp(op_name, kernel::kTBE);
if (op_info != nullptr && op_info->is_ref()) {
MS_LOG(INFO) << "Find RefNode: " << op_name << ", full name: " << cnode->fullname_with_scope();
moved_backward_cnodes->clear();
moved_forward_cnodes->clear();
return nullptr;
}
}
size_t total_moved_size = it - last_grad_pos - 1;
if (moved_backward_cnodes->size() + moved_forward_cnodes->size() != total_moved_size) {
MS_LOG(DEBUG) << "Total number inconsistent, total cnode number: " << total_moved_size
<< ", while move forward size: " << moved_forward_cnodes->size()
<< ", moved backward size: " << moved_backward_cnodes->size();
moved_forward_cnodes->clear();
moved_backward_cnodes->clear();
return nullptr;
}
uint32_t subgraph_id = 0;
bool get_subgraph_id = false;
CNodePtr first_output_node_ptr = nullptr;
while (!get_subgraph_id && it < cnode_ptr_list.end()) {
auto inputs = GetInputKernels(*it);
for (auto &input : inputs) {
if (find(moved_backward_cnodes->begin(), moved_backward_cnodes->end(), input) != moved_backward_cnodes->end()) {
MS_LOG(DEBUG) << "get subgraph id: " << AnfAlgo::GetGraphId((*it).get());
get_subgraph_id = true;
subgraph_id = AnfAlgo::GetGraphId((*it).get());
first_output_node_ptr = *it;
break;
}
}
it++;
}
if (subgraph_id == 0) {
MS_LOG(INFO) << "The nodes moved backward were not used by any other nodes, no need moved";
moved_forward_cnodes->clear();
moved_backward_cnodes->clear();
return nullptr;
}
for (; it < cnode_ptr_list.end() && AnfAlgo::GetGraphId((*it).get()) != subgraph_id; it++) {
auto inputs = GetInputKernels(*it);
for (auto &input : inputs) {
if (find(moved_backward_cnodes->begin(), moved_backward_cnodes->end(), input) != moved_backward_cnodes->end()) {
MS_LOG(INFO) << "The nodes moved backward were used by nodes on different subgraphs, no need moved";
moved_forward_cnodes->clear();
moved_backward_cnodes->clear();
return nullptr;
}
}
}
return first_output_node_ptr;
}
void AscendStreamAssign::FinetuneSubgraphExecOrder(vector<CNodePtr> *cnodes) {
MS_EXCEPTION_IF_NULL(cnodes);
auto hcom_pos = find_if(cnodes->begin(), cnodes->end(),
[](CNodePtr &node_ptr) -> bool { return AnfAlgo::GetCNodeName(node_ptr) == "AllReduce"; });
if (hcom_pos == cnodes->end()) {
cnodes->clear();
return;
}
CNodePtr hcom_ptr = *hcom_pos;
vector<CNodePtr> ori_cnodes(cnodes->begin(), cnodes->end());
cnodes->clear();
vector<CNodePtr> atomic_addr_clean;
for (auto iter = ori_cnodes.begin(); iter < ori_cnodes.end(); iter++) {
if (AnfAlgo::GetCNodeName(*iter) == kAtomicAddrCleanOpName) {
atomic_addr_clean.emplace_back(*iter);
continue;
}
auto inputs = GetInputKernels(*iter);
auto last_input_pos = cnodes->end();
for (auto &input : inputs) {
auto pos = find(cnodes->begin(), cnodes->end(), input);
if (pos != cnodes->end()) {
last_input_pos = (last_input_pos == cnodes->end() || last_input_pos < pos) ? pos : last_input_pos;
}
}
if (last_input_pos == cnodes->end()) {
auto hcom_it = find(cnodes->begin(), cnodes->end(), hcom_ptr);
if (hcom_it == cnodes->end() || AnfAlgo::GetCNodeName(*iter) == kLabelGotoOpName ||
AnfAlgo::GetCNodeName(*iter) == kLabelSetOpName || AnfAlgo::GetCNodeName(*iter) == kLabelSwitchOpName) {
cnodes->emplace_back(*iter);
} else {
cnodes->insert(hcom_it, *iter);
}
} else {
cnodes->insert(last_input_pos + 1, *iter);
}
}
for (auto &node : atomic_addr_clean) {
auto inputs = GetInputKernels(node);
auto first_input_pos = cnodes->end();
for (auto &input : inputs) {
auto pos = find(cnodes->begin(), cnodes->end(), input);
first_input_pos = (first_input_pos == cnodes->end() || first_input_pos > pos) ? pos : first_input_pos;
}
if (first_input_pos == cnodes->end()) {
MS_LOG(DEBUG) << "node: " << node->fullname_with_scope() << " 's input was not found";
cnodes->clear();
return;
} else {
cnodes->insert(first_input_pos, node);
}
}
if (cnodes->size() != ori_cnodes.size()) {
MS_LOG(DEBUG) << "Total number inconsistent, original node size: " << ori_cnodes.size()
<< ", while the new size after finetune order is: " << cnodes->size();
cnodes->clear();
return;
}
}
// performance optimization for trailing time in distribute mode
// allreduce of the last batch of gradients and the optimizer can be done parallel
void AscendStreamAssign::TrailingTimeOptimizationByReorder(const NotNull<KernelGraphPtr> &graph_ptr) {
MS_LOG(INFO) << "Trailing time optimization begin";
vector<CNodePtr> last_grad_and_status;
CheckScenario(graph_ptr, &last_grad_and_status);
if (last_grad_and_status.empty()) {
MS_LOG(INFO) << "Unsuitable scenario, no optimization needed";
return;
}
auto cnode_ptr_list = graph_ptr->execution_order();
vector<CNodePtr> moved_forward_cnodes;
vector<CNodePtr> moved_backward_cnodes;
CNodePtr first_output_ptr =
GetCNodesNeededMoved(&moved_backward_cnodes, &moved_forward_cnodes, last_grad_and_status, graph_ptr);
if (moved_backward_cnodes.empty() || first_output_ptr == nullptr) {
MS_LOG(INFO) << "Unsuitable scenario, no optimization needed";
return;
}
uint32_t subgraph_id = AnfAlgo::GetGraphId(first_output_ptr.get());
auto last_grad_ptr = last_grad_and_status[0];
auto last_grad_pos = find(cnode_ptr_list.begin(), cnode_ptr_list.end(), last_grad_ptr);
vector<CNodePtr> cnodes(cnode_ptr_list.begin(), last_grad_pos + 1);
cnodes.insert(cnodes.end(), moved_forward_cnodes.begin(), moved_forward_cnodes.end());
auto pos = last_grad_pos + moved_forward_cnodes.size() + moved_backward_cnodes.size() + 1;
while (pos < cnode_ptr_list.end() && AnfAlgo::GetGraphId((*pos).get()) != subgraph_id) {
cnodes.emplace_back(*pos);
pos++;
}
vector<CNodePtr> subgraph_cnodes;
while (pos < cnode_ptr_list.end() && AnfAlgo::GetGraphId((*pos).get()) == subgraph_id) {
if (*pos != first_output_ptr) {
subgraph_cnodes.emplace_back(*pos);
} else {
subgraph_cnodes.insert(subgraph_cnodes.end(), moved_backward_cnodes.begin(), moved_backward_cnodes.end());
subgraph_cnodes.emplace_back(*pos);
}
pos++;
}
FinetuneSubgraphExecOrder(&subgraph_cnodes);
if (subgraph_cnodes.empty()) {
MS_LOG(INFO) << "Finetune subgraph execute order failed, no optimization needed";
return;
}
cnodes.insert(cnodes.end(), subgraph_cnodes.begin(), subgraph_cnodes.end());
cnodes.insert(cnodes.end(), pos, cnode_ptr_list.end());
if (cnodes.size() != cnode_ptr_list.size()) {
MS_LOG(INFO) << "Inconsistent cnodes number. Original size: " << cnode_ptr_list.size()
<< ", while new order cnodes size: " << cnodes.size();
return;
}
for (auto &node : subgraph_cnodes) {
AnfAlgo::SetGraphId(subgraph_id, node.get());
}
graph_ptr->set_execution_order(cnodes);
MS_LOG(INFO) << "Trailing time optimization end";
}
// section 2
void AscendStreamAssign::AssignAllNodesStream(const NotNull<KernelGraphPtr> &graph_ptr) {
auto cnode_ptr_list = graph_ptr->execution_order();

View File

@ -161,6 +161,13 @@ class AscendStreamAssign {
void GetProcessedStream(const NotNull<KernelGraphPtr> &graph_ptr);
void GetNeedActiveStreams(const NotNull<KernelGraphPtr> &graph_ptr);
void ReorderIndependentOrders(const NotNull<KernelGraphPtr> &graph_ptr);
void CheckScenario(const NotNull<KernelGraphPtr> &graph_ptr, vector<CNodePtr> *last_grad_and_status);
CNodePtr GetCNodesNeededMoved(vector<CNodePtr> *moved_backward_cnodes, vector<CNodePtr> *moved_forward_cnodes,
const vector<CNodePtr> &last_grad_and_status, const NotNull<KernelGraphPtr> &graph_ptr);
void FinetuneSubgraphExecOrder(vector<CNodePtr> *cnodes);
void TrailingTimeOptimizationByReorder(const NotNull<KernelGraphPtr> &graph_ptr);
uint32_t GetMaxIndexTarget(const NotNull<KernelGraphPtr> &graph_ptr);
uint32_t GetIndexByKey(const NotNull<KernelGraphPtr> &graph_ptr, const CNodeKey &key);
uint32_t GetIndependentStreamSwitchStreamId(const NotNull<KernelGraphPtr> &graph_ptr);