event for kernel by kernel

This commit is contained in:
baihuawei 2022-05-10 14:37:40 +08:00
parent 791cd0e237
commit 8a0595305c
14 changed files with 436 additions and 191 deletions

View File

@ -106,8 +106,8 @@ class BACKEND_EXPORT KernelGraph : public FuncGraph {
input_nodes_ = graph.input_nodes_;
pre_graphs_ = graph.pre_graphs_;
post_graphs_ = graph.post_graphs_;
allreduce_from_send_recv_pairs_ = graph.allreduce_from_send_recv_pairs_;
allreduce_to_send_recv_pairs_ = graph.allreduce_to_send_recv_pairs_;
send_recv_pairs_for_parallel_op_inputs_ = graph.send_recv_pairs_for_parallel_op_inputs_;
send_recv_pairs_for_parallel_op_outputs_ = graph.send_recv_pairs_for_parallel_op_outputs_;
size_t pre_graph_finished_count = graph.pre_graph_finished_count_;
pre_graph_finished_count_ = pre_graph_finished_count;
size_t post_graph_finished_count = graph.post_graph_finished_count_;
@ -375,18 +375,34 @@ class BACKEND_EXPORT KernelGraph : public FuncGraph {
}
// end of handle graph dependency
// The interface of allreduce send/recv pairs map.
void InsertFromSendRecvPair(const CNodePtr &allreduce, const std::pair<CNodePtr, CNodePtr> &send_recv_pair) {
allreduce_from_send_recv_pairs_[allreduce] = send_recv_pair;
// The interface of parallel op send/recv pairs map.
// Record a (send, recv) event-kernel pair generated for one input edge of a
// parallel op. operator[] default-constructs the vector on first insertion,
// so the original find / insert-else-emplace split collapses to one lookup.
void InsertSendRecvPairForParallelOpInputs(const CNodePtr &parallel_op,
                                           const std::pair<CNodePtr, CNodePtr> &send_recv_pair) {
  send_recv_pairs_for_parallel_op_inputs_[parallel_op].emplace_back(send_recv_pair);
}
void InsertToSendRecvPair(const CNodePtr &allreduce, const std::pair<CNodePtr, CNodePtr> &send_recv_pair) {
allreduce_to_send_recv_pairs_[allreduce] = send_recv_pair;
// Record a (send, recv) event-kernel pair generated for one output edge of a
// parallel op. operator[] default-constructs the vector on first insertion,
// so the original find / insert-else-emplace split collapses to one lookup.
void InsertSendRecvPairForParallelOpOutputs(const CNodePtr &parallel_op,
                                            const std::pair<CNodePtr, CNodePtr> &send_recv_pair) {
  send_recv_pairs_for_parallel_op_outputs_[parallel_op].emplace_back(send_recv_pair);
}
const mindspore::HashMap<CNodePtr, std::pair<CNodePtr, CNodePtr>> &allreduce_from_send_recv_pairs() const {
return allreduce_from_send_recv_pairs_;
// Read-only access to the send/recv pairs inserted for parallel-op inputs.
const mindspore::HashMap<CNodePtr, std::vector<std::pair<CNodePtr, CNodePtr>>>
  &send_recv_pairs_for_parallel_op_inputs() const {
  return send_recv_pairs_for_parallel_op_inputs_;
}
const mindspore::HashMap<CNodePtr, std::pair<CNodePtr, CNodePtr>> &allreduce_to_send_recv_pairs() const {
return allreduce_to_send_recv_pairs_;
// Read-only access to the send/recv pairs inserted for parallel-op outputs.
const mindspore::HashMap<CNodePtr, std::vector<std::pair<CNodePtr, CNodePtr>>>
  &send_recv_pairs_for_parallel_op_outputs() const {
  return send_recv_pairs_for_parallel_op_outputs_;
}
uint32_t label_num() const { return label_num_; }
@ -521,10 +537,11 @@ class BACKEND_EXPORT KernelGraph : public FuncGraph {
std::map<session::KernelWithIndex, session::KernelWithIndex, session::KernelWithIndexCmp> nop_node_output_map_;
mindspore::HashMap<uint32_t, std::weak_ptr<session::KernelGraph>> pre_graphs_;
mindspore::HashMap<uint32_t, std::weak_ptr<session::KernelGraph>> post_graphs_;
// The send/recv pairs inserted for allreduce, the key is allreduce kernel, the first of pair is send node, the second
// of pair is recv node.
mindspore::HashMap<CNodePtr, std::pair<CNodePtr, CNodePtr>> allreduce_from_send_recv_pairs_;
mindspore::HashMap<CNodePtr, std::pair<CNodePtr, CNodePtr>> allreduce_to_send_recv_pairs_;
// key:parallel op ptr, value:vector of <send op receive op > pairs
mindspore::HashMap<CNodePtr, std::vector<std::pair<CNodePtr, CNodePtr>>> send_recv_pairs_for_parallel_op_inputs_;
// key:parallel op ptr, value:vector of <send op receive op > pairs
mindspore::HashMap<CNodePtr, std::vector<std::pair<CNodePtr, CNodePtr>>> send_recv_pairs_for_parallel_op_outputs_;
std::atomic<size_t> pre_graph_finished_count_{0};
std::atomic<size_t> post_graph_finished_count_{0};
bool first_step_{true};

View File

@ -743,6 +743,29 @@ bool AscendKernelRuntime::Run(const session::KernelGraph &graph, bool is_task_si
return ret;
}
// Bind every kernel's AscendKernelMod to the rt stream matching its assigned
// stream id, creating each stream lazily the first time its id is seen and
// caching it in stream_id_map_.
void AscendKernelRuntime::SetKernelModRtStream(const NotNull<KernelGraphPtr> &graph_ptr) {
  for (const auto &kernel : graph_ptr->execution_order()) {
    auto *ascend_kernel_mod = dynamic_cast<kernel::AscendKernelMod *>(AnfAlgo::GetKernelMod(kernel));
    MS_EXCEPTION_IF_NULL(ascend_kernel_mod);
    const auto stream_id = AnfAlgo::GetStreamId(kernel);
    auto iter = stream_id_map_.find(stream_id);
    if (iter != stream_id_map_.end()) {
      // Stream already created for this id: reuse it.
      ascend_kernel_mod->set_stream(iter->second);
      continue;
    }
    void *new_stream = nullptr;
    auto ret = rtStreamCreate(&new_stream, 0);
    if (ret != RT_ERROR_NONE) {
      MS_LOG(EXCEPTION) << "create communication stream failed, ret:" << ret;
    }
    stream_id_map_[stream_id] = new_stream;
    ascend_kernel_mod->set_stream(new_stream);
  }
}
void AscendKernelRuntime::SetKernelModStream(const std::vector<CNodePtr> &kernels,
std::vector<size_t> *last_stream_nodes) {
std::map<void *, size_t> last_kernel;
@ -872,91 +895,6 @@ void AscendKernelRuntime::GenKernelEvents(const session::KernelGraph &graph) {
graph_kernel_events_map_[graph.graph_id()] = std::move(kernel_events);
}
// Generate per-kernel pre-run / post-run event closures for MindRT
// (kernel-by-kernel) execution. Whenever a kernel consumes a value produced
// on a different rt stream, a device event is recorded after the producer and
// waited on before the consumer. Results are cached per graph id in
// graph_kernel_events_map_.
void AscendKernelRuntime::GenKernelEventsForMindRT(const session::KernelGraph &graph) {
  auto &kernels = graph.execution_order();
  // Skip empty graphs and graphs whose events were already generated.
  if (kernels.empty() || graph_kernel_events_map_.find(graph.graph_id()) != graph_kernel_events_map_.end()) {
    return;
  }
  std::vector<size_t> last_stream_nodes;
  SetKernelModStream(kernels, &last_stream_nodes);
  // first: closures run before each kernel; second: closures run after it.
  auto kernel_events = std::pair<std::map<AnfNodePtr, std::vector<std::function<void()>>>,
                                 std::map<AnfNodePtr, std::vector<std::function<void()>>>>();
  auto &kernel_pre_run_events = kernel_events.first;
  auto &kernel_post_run_events = kernel_events.second;
  for (size_t i = 0; i < kernels.size(); ++i) {
    auto &kernel = kernels[i];
    auto curr_stream_id = AnfAlgo::GetStreamId(kernel);
    if (stream_id_map_.find(curr_stream_id) == stream_id_map_.end()) {
      MS_LOG(EXCEPTION) << "Stream " << curr_stream_id << "has not been created.";
    }
    auto wait_stream = stream_id_map_[curr_stream_id];
    // Collect the real kernels this node transitively reads from.
    std::vector<AnfNodePtr> used_kernels;
    std::set<AnfNodePtr> visited_kernels;
    common::AnfAlgo::GetAllVisitedCNode(kernel, &used_kernels, &visited_kernels);
    bool found_depend = false;
    // Guards against recording more than one event per distinct producer.
    std::set<AnfNodePtr> record_nodes;
    // set events for nodes and its input: [input_node_stream, node_stream]
    for (auto &visited : used_kernels) {
      auto pre_cnode_stream_id = AnfAlgo::GetStreamId(visited);
      if (stream_id_map_.find(pre_cnode_stream_id) == stream_id_map_.end()) {
        MS_LOG(EXCEPTION) << "Stream " << pre_cnode_stream_id << "has not been created.";
      }
      // Same-stream inputs are already ordered by the stream; no event needed.
      if (pre_cnode_stream_id == curr_stream_id) {
        found_depend = true;
        continue;
      }
      if (record_nodes.find(visited) == record_nodes.end()) {
        found_depend = true;
        auto record_stream = stream_id_map_[pre_cnode_stream_id];
        auto event = CreateDeviceEvent();
        event->set_wait_stream(wait_stream);
        event->set_record_stream(record_stream);
        // Record after the producer runs, wait before the consumer runs.
        kernel_post_run_events[visited].emplace_back([event]() { event->RecordEvent(); });
        kernel_pre_run_events[kernel].emplace_back([event]() { event->WaitEvent(); });
      }
      record_nodes.insert(visited);
    }
    // for start_node(no inputs), set event [stream_, start_node_stream]
    if (!found_depend && wait_stream != stream_) {
      auto pre_event = CreateDeviceEvent();
      pre_event->set_wait_stream(wait_stream);
      pre_event->set_record_stream(stream_);
      kernel_pre_run_events[kernel].emplace_back([pre_event]() { pre_event->RecordEvent(); });
      kernel_pre_run_events[kernel].emplace_back([pre_event]() { pre_event->WaitEvent(); });
    }
  }
  // find end node of graph by last_stream_nodes, and set event [last_node_stream, stream_]
  ProcessBoundaryEvent(kernels, &kernel_post_run_events, last_stream_nodes);
  graph_kernel_events_map_[graph.graph_id()] = std::move(kernel_events);
}
// Return the (pre-run, post-run) event closures registered for `kernel` in
// its graph's event map; both vectors are empty when nothing was registered.
// The original copied BOTH whole per-graph maps before a single lookup; this
// version looks the kernel up in place through const references.
std::pair<vector<std::function<void()>>, vector<std::function<void()>>> AscendKernelRuntime::GetKernelEventFuncs(
  const CNodePtr &kernel) const {
  std::vector<std::function<void()>> kernel_pre_event_funcs;
  std::vector<std::function<void()>> kernel_post_event_funcs;
  auto graph_id = AnfAlgo::GetGraphId(kernel.get());
  auto events_iter = graph_kernel_events_map_.find(graph_id);
  if (events_iter != graph_kernel_events_map_.end()) {
    const auto &kernels_pre_event_funcs = events_iter->second.first;
    const auto &kernels_post_event_funcs = events_iter->second.second;
    auto pre_event_funcs_iter = kernels_pre_event_funcs.find(kernel);
    if (pre_event_funcs_iter != kernels_pre_event_funcs.end()) {
      kernel_pre_event_funcs = pre_event_funcs_iter->second;
    }
    auto post_event_funcs_iter = kernels_post_event_funcs.find(kernel);
    if (post_event_funcs_iter != kernels_post_event_funcs.end()) {
      kernel_post_event_funcs = post_event_funcs_iter->second;
    }
  }
  return std::make_pair(kernel_pre_event_funcs, kernel_post_event_funcs);
}
void AscendKernelRuntime::ProcessBoundaryEvent(
const std::vector<CNodePtr> &kernels, std::map<AnfNodePtr, std::vector<std::function<void()>>> *kernel_run_events,
const std::vector<size_t> &last_stream_nodes) {

View File

@ -44,11 +44,9 @@ class AscendKernelRuntime : public KernelRuntime {
bool Init() override;
bool LoadData(const session::KernelGraph &graph) override;
bool GenTask(const session::KernelGraph &graph);
void GenKernelEventsForMindRT(const session::KernelGraph &graph);
void GenKernelEvents(const session::KernelGraph &graph) override;
std::pair<vector<std::function<void()>>, vector<std::function<void()>>> GetKernelEventFuncs(
const CNodePtr &kernel) const;
void SetKernelModStream(const std::vector<CNodePtr> &kernels, std::vector<size_t> *last_stream_nodes);
void SetKernelModRtStream(const NotNull<KernelGraphPtr> &graph_ptr);
void ProcessBoundaryEvent(const std::vector<CNodePtr> &kernels,
std::map<AnfNodePtr, std::vector<std::function<void()>>> *kernel_run_events,
const std::vector<size_t> &last_stream_nodes);

View File

@ -217,28 +217,11 @@ void AscendStreamAssign::AssignStreamForNonTaskSink(const std::vector<CNodePtr>
if (kernels.empty()) {
return;
}
if (stream_groups_.empty()) {
stream_groups_.emplace_back(std::vector<uint32_t>{kDefaultStreamIndex});
stream_groups_.emplace_back(std::vector<uint32_t>{kIndependentStreamIndex});
stream_groups_.emplace_back(std::vector<uint32_t>{kWorldGroupStreamIndex});
}
group_stream_id_map_[kHcclWorldGroup] = kWorldGroupStreamIndex;
for (size_t i = 0; i < kernels.size(); ++i) {
auto &node = kernels[i];
if (common::AnfAlgo::IsCommunicationOp(node)) {
auto group = common::AnfAlgo::GetNodeAttr<std::string>(node, kAttrGroup);
auto iter = group_stream_id_map_.find(group);
if (iter == group_stream_id_map_.end()) {
auto id = SizeToUint(group_stream_id_map_.size()) + kWorldGroupStreamIndex;
group_stream_id_map_[group] = id;
AnfAlgo::SetStreamId(id, node.get());
stream_groups_.emplace_back(std::vector<uint32_t>{id});
} else {
auto id = iter->second;
AnfAlgo::SetStreamId(id, node.get());
}
} else if (AnfAlgo::IsIndependentNode(node)) {
AnfAlgo::SetStreamId(kIndependentStreamIndex, node.get());
AnfAlgo::SetStreamId(kWorldGroupStreamIndex, node.get());
} else {
AnfAlgo::SetStreamId(kDefaultStreamIndex, node.get());
}
@ -251,6 +234,234 @@ void AscendStreamAssign::AssignStreamForNonTaskSink(const std::vector<CNodePtr>
}
}
void GenKernelIoExecInfoMap(const NotNull<KernelGraphPtr> &kernel_graph,
mindspore::HashMap<CNodePtr, NodeIoExecInfoPtr> *kernel_io_exec_info_map) {
auto &exec_kernels = kernel_graph->execution_order();
for (size_t i = 0; i < exec_kernels.size(); ++i) {
auto &process_kernel = exec_kernels[i];
MS_EXCEPTION_IF_NULL(process_kernel);
auto process_exec_info = std::make_shared<NodeExecInfo>();
MS_EXCEPTION_IF_NULL(process_exec_info);
process_exec_info->node = process_kernel;
process_exec_info->stream_id = AnfAlgo::GetStreamId(process_kernel);
process_exec_info->execution_order_index = i;
auto process_io_exec_info = std::make_shared<NodeIoExecInfo>();
MS_EXCEPTION_IF_NULL(process_io_exec_info);
process_io_exec_info->node_exec_info = process_exec_info;
process_io_exec_info->inputs = {};
process_io_exec_info->outputs = {};
(*kernel_io_exec_info_map)[process_kernel] = process_io_exec_info;
}
for (auto &process_kernel : exec_kernels) {
MS_EXCEPTION_IF_NULL(process_kernel);
auto process_iter = kernel_io_exec_info_map->find(process_kernel);
if (process_iter == kernel_io_exec_info_map->end()) {
MS_LOG(ERROR) << "Can't get kernel io execution info for " << process_kernel->fullname_with_scope();
continue;
}
auto process_io_exec_info = process_iter->second;
auto process_exec_info = process_iter->second->node_exec_info;
auto inputs = process_kernel->inputs();
for (size_t i = 1; i < inputs.size(); i++) {
auto input_node = common::AnfAlgo::VisitKernelWithReturnType(inputs[i], 0).first;
MS_EXCEPTION_IF_NULL(input_node);
if (AnfUtils::IsRealCNodeKernel(input_node)) {
auto input_kernel = input_node->cast<CNodePtr>();
auto iter = kernel_io_exec_info_map->find(input_kernel);
if (iter == kernel_io_exec_info_map->end()) {
MS_LOG(ERROR) << "Can't get kernel io execution info for " << process_kernel->fullname_with_scope()
<< "'s input node " << input_kernel->fullname_with_scope();
continue;
}
auto input_io_exec_info = iter->second;
auto input_exec_info = iter->second->node_exec_info;
process_io_exec_info->inputs.push_back(input_exec_info);
input_io_exec_info->outputs.push_back(process_exec_info);
}
}
}
}
void AscendStreamAssign::InsertEventsForInputs(const NotNull<KernelGraphPtr> &kernel_graph, const CNodePtr &kernel,
const NodeIoExecInfoPtr &io_exec_info,
mindspore::HashMap<AnfNodePtr, std::vector<CNodePtr>> *kernel_send,
mindspore::HashMap<AnfNodePtr, std::vector<CNodePtr>> *kernel_recv) {
auto process_stream_id = AnfAlgo::GetStreamId(kernel);
auto input_exec_info_list = io_exec_info->inputs;
mindspore::HashMap<uint32_t, NodeExecInfoPtr> stream_max_exec_node_map;
for (auto &input : input_exec_info_list) {
MS_EXCEPTION_IF_NULL(input);
auto input_stream_id = input->stream_id;
auto iter = stream_max_exec_node_map.find(input_stream_id);
if (iter == stream_max_exec_node_map.end()) {
stream_max_exec_node_map[input_stream_id] = input;
} else {
MS_EXCEPTION_IF_NULL(iter->second);
if (input->execution_order_index > iter->second->execution_order_index) {
iter->second = input;
}
}
}
for (auto input_exec : stream_max_exec_node_map) {
MS_EXCEPTION_IF_NULL(input_exec.second);
if (input_exec.second->stream_id == process_stream_id) {
continue;
}
InsertEvents(kernel_graph, kernel, input_exec.second->node, kernel_send, kernel_recv, kernel);
}
}
void AscendStreamAssign::InsertEventsForOutputs(const NotNull<KernelGraphPtr> &kernel_graph, const CNodePtr &kernel,
const NodeIoExecInfoPtr &io_exec_info,
mindspore::HashMap<AnfNodePtr, std::vector<CNodePtr>> *kernel_send,
mindspore::HashMap<AnfNodePtr, std::vector<CNodePtr>> *kernel_recv) {
auto process_stream_id = AnfAlgo::GetStreamId(kernel);
auto output_exec_info_list = io_exec_info->outputs;
mindspore::HashMap<uint32_t, NodeExecInfoPtr> stream_min_exec_node_map;
for (auto &output : output_exec_info_list) {
MS_EXCEPTION_IF_NULL(output);
auto output_stream_id = output->stream_id;
auto iter = stream_min_exec_node_map.find(output_stream_id);
if (iter == stream_min_exec_node_map.end()) {
stream_min_exec_node_map[output_stream_id] = output;
} else {
MS_EXCEPTION_IF_NULL(iter->second);
if (output->execution_order_index < iter->second->execution_order_index) {
iter->second = output;
}
}
}
for (auto output_exec : stream_min_exec_node_map) {
MS_EXCEPTION_IF_NULL(output_exec.second);
if (output_exec.second->stream_id == process_stream_id) {
continue;
}
InsertEvents(kernel_graph, kernel, kernel, kernel_send, kernel_recv, output_exec.second->node);
}
// parallel op has output tensor, and it didn't connect to other kernel, it's output is graph output, sync it.
if (output_exec_info_list.empty() && (common::AnfAlgo::GetOutputTensorNum(kernel) != 0)) {
InsertEvents(kernel_graph, kernel, kernel, kernel_send, kernel_recv, kernel_graph->output());
}
}
// Create one send/recv kernel pair for a cross-stream dependency: the send
// (record) kernel is scheduled right after `node_before_send` on its stream,
// and the recv (wait) kernel right before `node_after_recv` on its stream.
// The rt event handle is passed to both kernels via node attributes, and the
// pair is registered on the kernel graph keyed by `parallel_cnode`.
void AscendStreamAssign::InsertEvents(const NotNull<KernelGraphPtr> &kernel_graph, const CNodePtr &parallel_cnode,
                                      const AnfNodePtr &node_before_send,
                                      mindspore::HashMap<AnfNodePtr, std::vector<CNodePtr>> *kernel_send,
                                      mindspore::HashMap<AnfNodePtr, std::vector<CNodePtr>> *kernel_recv,
                                      const AnfNodePtr &node_after_recv) {
  AscendStreamMng &resource_manager = AscendStreamMng::GetInstance();
  uint32_t event_id = resource_manager.ApplyNewEvent();
  auto event = resource_manager.ApplyRtEvent();
  // Send side: record the event on the sender's stream. The raw rt event
  // handle travels as an integer attribute and is reinterpreted at kernel Init.
  auto send_cnode = CreateSendApplyKernel(kernel_graph, event_id, AnfAlgo::GetStreamId(node_before_send));
  common::AnfAlgo::SetNodeAttr(kAttrRecordEvent, MakeValue(reinterpret_cast<uintptr_t>(event)), send_cnode);
  auto send_iter = kernel_send->find(node_before_send);
  if (send_iter == kernel_send->end()) {
    (*kernel_send)[node_before_send] = {send_cnode};
  } else {
    send_iter->second.push_back(send_cnode);
  }
  // Recv side: wait for the event on the receiver's stream.
  CNodePtr recv_cnode = CreateRecvApplyKernel(kernel_graph, event_id, AnfAlgo::GetStreamId(node_after_recv));
  common::AnfAlgo::SetNodeAttr(kAttrWaitEvent, MakeValue(reinterpret_cast<uintptr_t>(event)), recv_cnode);
  auto process_iter = kernel_recv->find(node_after_recv);
  if (process_iter == kernel_recv->end()) {
    (*kernel_recv)[node_after_recv] = {recv_cnode};
  } else {
    process_iter->second.push_back(recv_cnode);
  }
  // When the parallel op itself is the sender this pair guards one of its
  // outputs; otherwise it guards one of its inputs.
  if (parallel_cnode == node_before_send) {
    kernel_graph->InsertSendRecvPairForParallelOpOutputs(parallel_cnode, std::make_pair(send_cnode, recv_cnode));
    MS_LOG(INFO) << "Generate send/recv for parallel op " << parallel_cnode->fullname_with_scope() << "'s output."
                 << "Send node " << send_cnode->fullname_with_scope() << " after "
                 << node_before_send->fullname_with_scope() << ", recv node " << recv_cnode->fullname_with_scope()
                 << " before " << node_after_recv->fullname_with_scope();
  } else {
    kernel_graph->InsertSendRecvPairForParallelOpInputs(parallel_cnode, std::make_pair(send_cnode, recv_cnode));
    MS_LOG(INFO) << "Generate send/recv for parallel op " << parallel_cnode->fullname_with_scope() << "'s input."
                 << "Send node " << send_cnode->fullname_with_scope() << " after "
                 << node_before_send->fullname_with_scope() << ", recv node " << recv_cnode->fullname_with_scope()
                 << " before " << node_after_recv->fullname_with_scope();
  }
}
void AscendStreamAssign::GenEventsForParallelOp(const NotNull<KernelGraphPtr> &kernel_graph,
mindspore::HashMap<AnfNodePtr, std::vector<CNodePtr>> *kernel_send,
mindspore::HashMap<AnfNodePtr, std::vector<CNodePtr>> *kernel_recv) {
MS_LOG(DEBUG) << "Start GenEventsForParallelOp...";
auto exec_kernels = kernel_graph->execution_order();
mindspore::HashMap<CNodePtr, NodeIoExecInfoPtr> kernel_io_exec_info_map;
GenKernelIoExecInfoMap(kernel_graph, &kernel_io_exec_info_map);
for (auto &process_kernel : exec_kernels) {
MS_EXCEPTION_IF_NULL(process_kernel);
auto process_stream_id = AnfAlgo::GetStreamId(process_kernel);
if (process_stream_id == kDefaultStreamIndex) {
continue;
}
MS_LOG(DEBUG) << "Start GenEvents For ParallelOp " << process_kernel->fullname_with_scope();
auto process_iter = kernel_io_exec_info_map.find(process_kernel);
if (process_iter == kernel_io_exec_info_map.end()) {
MS_LOG(ERROR) << "Can't get node io execution info for " << process_kernel->fullname_with_scope();
continue;
}
auto process_io_exec_info = process_iter->second;
InsertEventsForInputs(kernel_graph, process_kernel, process_io_exec_info, kernel_send, kernel_recv);
InsertEventsForOutputs(kernel_graph, process_kernel, process_io_exec_info, kernel_send, kernel_recv);
}
MS_LOG(DEBUG) << "Finish GenEventsForParallelOp.";
}
// Rebuild the graph execution order with the generated event kernels spliced
// in: every recv in `recv_before_node` is placed immediately before its
// anchor kernel and every send in `send_after_node` immediately after it.
// Recvs anchored on the graph output node (which does not appear in the
// execution order) are appended at the very end.
void AscendStreamAssign::UpdateEventsToExecutionOrder(
  const NotNull<KernelGraphPtr> &kernel_graph,
  const mindspore::HashMap<AnfNodePtr, std::vector<CNodePtr>> &send_after_node,
  const mindspore::HashMap<AnfNodePtr, std::vector<CNodePtr>> &recv_before_node) const {
  MS_LOG(DEBUG) << "Start UpdateEventsToExecutionOrder...";
  auto exec_kernels = kernel_graph->execution_order();
  std::vector<CNodePtr> new_exec_orders;
  for (auto &kernel : exec_kernels) {
    // Recvs first, so the kernel waits for its cross-stream inputs.
    auto before_iter = recv_before_node.find(kernel);
    if (before_iter != recv_before_node.end()) {
      for (auto &recv : before_iter->second) {
        new_exec_orders.push_back(recv);
      }
    }
    new_exec_orders.push_back(kernel);
    // Sends after, so consumers on other streams can be released.
    auto after_iter = send_after_node.find(kernel);
    if (after_iter != send_after_node.end()) {
      for (auto send : after_iter->second) {
        new_exec_orders.push_back(send);
      }
    }
  }
  // Recvs guarding the graph output run last.
  auto graph_output = kernel_graph->output();
  auto graph_output_iter = recv_before_node.find(graph_output);
  if (graph_output_iter != recv_before_node.end()) {
    for (auto &recv : graph_output_iter->second) {
      new_exec_orders.push_back(recv);
    }
  }
  kernel_graph->set_execution_order(new_exec_orders);
  MS_LOG(DEBUG) << "Finish UpdateEventsToExecutionOrder.";
}
void AscendStreamAssign::InsertEventForNonTaskSink(const NotNull<KernelGraphPtr> &kernel_graph) {
mindspore::HashMap<AnfNodePtr, std::vector<CNodePtr>> kernel_send;
mindspore::HashMap<AnfNodePtr, std::vector<CNodePtr>> kernel_recv;
AnfAlgo::SetStreamId(kDefaultStreamIndex, kernel_graph->output().get());
GenEventsForParallelOp(kernel_graph, &kernel_send, &kernel_recv);
UpdateEventsToExecutionOrder(kernel_graph, kernel_send, kernel_recv);
InsertEventForMicroBatchIndependent(kernel_graph);
}
void AscendStreamAssign::AssignStream(const NotNull<KernelGraphPtr> &graph_ptr) {
if (graph_ptr->is_dynamic_shape()) {
MS_LOG(WARNING) << "Dynamic shape do not need to assign stream.";
@ -263,6 +474,7 @@ void AscendStreamAssign::AssignStream(const NotNull<KernelGraphPtr> &graph_ptr)
if (!IsTaskSink()) {
auto kernels = graph_ptr->execution_order();
AssignStreamForNonTaskSink(kernels);
InsertEventForNonTaskSink(graph_ptr);
MS_LOG(INFO) << "After finish stream assign";
graph_ptr->PrintGraphExecuteOrder();
PROF_END(assign_stream);
@ -2535,6 +2747,7 @@ void AscendStreamAssign::InsertEventForMicroBatchIndependent(const NotNull<Kerne
uint32_t cur_event_id = resource_manager.ApplyNewEvent();
CNodePtr send_cnode = CreateSendApplyKernel(graph_ptr, cur_event_id, AnfAlgo::GetStreamId((cnode)));
CNodePtr recv_cnode = CreateRecvApplyKernel(graph_ptr, cur_event_id, AnfAlgo::GetStreamId(next_gen_mask));
graph_ptr->InsertSendRecvPairForParallelOpInputs(next_gen_mask, std::make_pair(send_cnode, recv_cnode));
node_send_map[cnode] = send_cnode;
node_recv_map[next_gen_mask] = recv_cnode;
}

View File

@ -47,6 +47,21 @@ using GroupGraphMap = std::map<std::string, std::map<uint32_t, std::vector<CNode
const uint32_t kInvalidStreamId = UINT32_MAX;
const uint32_t kInvalidEventId = UINT32_MAX;
enum StreamActiveKind { kInvalid = 0, kHead, kMiddle, kTail };
struct NodeExecInfo {
CNodePtr node;
uint32_t stream_id;
size_t execution_order_index;
};
using NodeExecInfoPtr = std::shared_ptr<NodeExecInfo>;
struct NodeIoExecInfo {
NodeExecInfoPtr node_exec_info;
std::vector<NodeExecInfoPtr> inputs;
std::vector<NodeExecInfoPtr> outputs;
};
using NodeIoExecInfoPtr = std::shared_ptr<NodeIoExecInfo>;
class AscendStreamAssign {
public:
static AscendStreamAssign &GetInstance() {
@ -191,6 +206,27 @@ class AscendStreamAssign {
uint32_t max_stream_count_ = 0;
uint32_t max_task_count_ = 0;
// insert event for kernel by kernel
void InsertEventForNonTaskSink(const NotNull<KernelGraphPtr> &kernel_graph);
void GenEventsForParallelOp(const NotNull<KernelGraphPtr> &kernel_graph,
HashMap<AnfNodePtr, vector<CNodePtr>> *kernel_send,
HashMap<AnfNodePtr, vector<CNodePtr>> *kernel_recv);
void UpdateEventsToExecutionOrder(const NotNull<KernelGraphPtr> &kernel_graph,
const HashMap<AnfNodePtr, vector<CNodePtr>> &recv_before_node,
const HashMap<AnfNodePtr, vector<CNodePtr>> &send_after_node) const;
void InsertEventsForInputs(const NotNull<KernelGraphPtr> &kernel_graph, const CNodePtr &kernel,
const NodeIoExecInfoPtr &io_exec_info, HashMap<AnfNodePtr, vector<CNodePtr>> *kernel_send,
HashMap<AnfNodePtr, vector<CNodePtr>> *kernel_recv);
void InsertEventsForOutputs(const NotNull<KernelGraphPtr> &kernel_graph, const CNodePtr &kernel,
const NodeIoExecInfoPtr &io_exec_info, HashMap<AnfNodePtr, vector<CNodePtr>> *kernel_send,
HashMap<AnfNodePtr, vector<CNodePtr>> *kernel_recv);
void InsertEvents(const NotNull<KernelGraphPtr> &kernel_graph, const CNodePtr &parallel_cnode,
const AnfNodePtr &node_before_send, HashMap<mindspore::AnfNodePtr, vector<CNodePtr>> *kernel_send,
HashMap<AnfNodePtr, vector<CNodePtr>> *kernel_recv, const AnfNodePtr &node_after_recv);
};
} // namespace ascend
} // namespace device

View File

@ -17,6 +17,10 @@
#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_STREAM_MANAGER_H_
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_STREAM_MANAGER_H_
#include <memory>
#include "utils/hash_map.h"
#include "runtime/event.h"
namespace mindspore {
namespace device {
namespace ascend {
@ -36,6 +40,16 @@ class AscendStreamMng {
uint32_t ApplyNewEvent() { return cur_event_num_++; }
// Create a new runtime event handle; logs an error and returns nullptr when
// rtEventCreate fails. A plain local is enough to receive the handle — the
// shared_ptr the original allocated added a heap allocation that owned
// nothing and released nothing.
rtEvent_t ApplyRtEvent() {
  rtEvent_t rt_event = nullptr;
  auto ret = rtEventCreate(&rt_event);
  if (ret != RT_ERROR_NONE) {
    MS_LOG(ERROR) << "rtEventCreate failed, ret:" << ret;
    rt_event = nullptr;
  }
  return rt_event;
}
void DeleteEvent() {
if (!cur_event_num_) {
MS_LOG(WARNING) << "total event num is 0, no event to delete";

View File

@ -367,13 +367,6 @@ void AscendDeviceContext::UpdateExecOrder(const KernelGraphPtr &graph) const {
node_atomics_.clear();
}
// Delegate MindRT kernel-event generation for the whole root graph to the
// kernel runtime instance.
void AscendDeviceContext::GenKernelEvents(const NotNull<KernelGraphPtr> &root_graph) const {
  MS_LOG(INFO) << "Start GenKernelEvents for graph " << root_graph->graph_id();
  MS_EXCEPTION_IF_NULL(runtime_instance_);
  runtime_instance_->GenKernelEventsForMindRT(*root_graph.get());
  MS_LOG(INFO) << "Finish!";
}
void AscendDeviceContext::SetAtomicCleanToNodes(const KernelGraphPtr &graph,
const std::map<CNodePtr, std::vector<CNodePtr>> &atomics_node) const {
// don't clear node_atomics_ in the end, since atomic_clean_nodes_ in kernel.h is weakptr
@ -422,6 +415,9 @@ void AscendDeviceContext::PreprocessBeforeRunGraph(const KernelGraphPtr &graph)
} else {
PreprocessBeforeRunSingleOpGraph(graph);
AscendStreamAssign::GetInstance().AssignStream(NOT_NULL(graph));
CreateKernel(graph->execution_order());
MS_EXCEPTION_IF_NULL(runtime_instance_);
runtime_instance_->SetKernelModRtStream(NOT_NULL(graph));
}
} catch (const std::exception &e) {
ReportErrorMessage();
@ -825,7 +821,7 @@ void *AscendDeviceContext::GetKernelStream(const CNodePtr &node) const {
auto stream = kernel_mod->stream();
if (stream == nullptr) {
stream = compute_stream_;
MS_LOG(INFO) << "Assign default compute stream for node " << node->fullname_with_scope();
MS_LOG(ERROR) << "Assign default compute stream for node " << node->fullname_with_scope();
}
return stream;
}
@ -945,7 +941,7 @@ bool AscendDeviceContext::LaunchAtomicClean(const CNodePtr &node, const std::vec
// Launch Atomic Node
auto kernel_mod = AnfAlgo::GetKernelMod(atomic_node);
MS_EXCEPTION_IF_NULL(kernel_mod);
return kernel_mod->Launch(atomic_inputs, {}, {}, compute_stream_);
return kernel_mod->Launch(atomic_inputs, {}, {}, GetKernelStream(node));
}
void AscendDeviceContext::InsertEventBeforeRunTask(const KernelGraphPtr &graph) const {

View File

@ -140,7 +140,6 @@ class AscendDeviceContext : public DeviceContext {
static bool IsGraphMode();
bool PySyncRuning() const;
bool MemoryCopyAsync(const CNodePtr &node, const vector<AddressPtr> &inputs, const vector<AddressPtr> &outputs) const;
void GenKernelEvents(const NotNull<KernelGraphPtr> &root_graph) const;
void InsertEventBeforeRunTask(const KernelGraphPtr &graph) const;
void SetAtomicCleanToNodes(const KernelGraphPtr &graph,
const std::map<CNodePtr, std::vector<CNodePtr>> &atomics_node) const;

View File

@ -18,6 +18,7 @@
#include "runtime/stream.h"
#include "utils/ms_context.h"
#include "plugin/device/ascend/hal/device/ge_runtime/task_info.h"
#include "plugin/device/ascend/hal/device/ascend_stream_manager.h"
#include "backend/common/session/anf_runtime_algorithm.h"
#include "include/common/utils/anfalgo.h"
@ -38,18 +39,28 @@ bool RecvKernel::Init(const AnfNodePtr &anf_node) {
MS_LOG(EXCEPTION) << "RecvKernel has no attr kAttrEventId";
}
event_id_ = GetValue<uint32_t>(primitive->GetAttr(kAttrEventId));
if (common::AnfAlgo::HasNodeAttr(kAttrWaitEvent, anf_node->cast<CNodePtr>())) {
event_ = reinterpret_cast<rtEvent_t>(GetValue<uintptr_t>(primitive->GetAttr(kAttrWaitEvent)));
}
MS_LOG(INFO) << "recv op event_id_:" << event_id_;
return true;
}
// Block `stream_ptr` until the paired Send kernel records `event_`, then
// reset the event so it can be recorded again on the next iteration.
// Fix: the block contained two stale leftover lines — a dead `stream_event`
// local and a duplicate `auto status` declaration waiting on that empty
// event — which shadowed the real wait on `event_` and did not compile.
bool RecvKernel::Launch(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
                        const std::vector<AddressPtr> &, void *stream_ptr) {
  MS_EXCEPTION_IF_NULL(event_);
  MS_EXCEPTION_IF_NULL(stream_ptr);
  auto status = rtStreamWaitEvent(stream_ptr, event_);
  if (status != RT_ERROR_NONE) {
    MS_LOG(ERROR) << "Recv rtStreamWaitEvent failed!";
    return false;
  }
  // Reset makes the event reusable for the next record/wait cycle.
  status = rtEventReset(event_, stream_ptr);
  if (status != RT_ERROR_NONE) {
    MS_LOG(EXCEPTION) << "rtEventReset failed, ret:" << status;
  }
  return true;
}

View File

@ -37,6 +37,7 @@ class RecvKernel : public RtKernel {
private:
uint32_t event_id_;
rtEvent_t event_;
};
MS_REG_RTKERNEL(streamrecv, RecvKernel);

View File

@ -17,6 +17,7 @@
#include "plugin/device/ascend/kernel/rts/send.h"
#include "runtime/event.h"
#include "plugin/device/ascend/hal/device/ge_runtime/task_info.h"
#include "plugin/device/ascend/hal/device/ascend_stream_manager.h"
#include "backend/common/session/anf_runtime_algorithm.h"
#include "include/common/utils/anfalgo.h"
@ -37,14 +38,19 @@ bool SendKernel::Init(const AnfNodePtr &anf_node) {
MS_LOG(EXCEPTION) << "SendKernel has no attr kAttrEventId";
}
event_id_ = GetValue<uint32_t>(primitive->GetAttr(kAttrEventId));
if (common::AnfAlgo::HasNodeAttr(kAttrRecordEvent, anf_node->cast<CNodePtr>())) {
event_ = reinterpret_cast<rtEvent_t>(GetValue<uintptr_t>(primitive->GetAttr(kAttrRecordEvent)));
}
MS_LOG(INFO) << "send op event id:" << event_id_;
return true;
}
bool SendKernel::Launch(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &, void *stream_ptr) {
rtEvent_t event{};
rtError_t status = rtEventRecord(event, stream_ptr);
MS_EXCEPTION_IF_NULL(event_);
MS_EXCEPTION_IF_NULL(stream_ptr);
rtError_t status = rtEventRecord(event_, stream_ptr);
if (status != RT_ERROR_NONE) {
MS_LOG(ERROR) << "Send op rtEventRecord failed!";
return false;

View File

@ -20,6 +20,7 @@
#include <vector>
#include "plugin/device/ascend/kernel/rts/rt_kernel.h"
#include "plugin/device/ascend/kernel/rts/rt_kernel_info.h"
#include "plugin/device/ascend/hal/device/ascend_event.h"
namespace mindspore {
namespace kernel {
@ -35,6 +36,7 @@ class SendKernel : public RtKernel {
private:
uint32_t event_id_;
rtEvent_t event_;
};
MS_REG_RTKERNEL(streamsend, SendKernel);

View File

@ -222,11 +222,11 @@ void CacheSendRecvCNodesForAllReduce(const std::shared_ptr<session::KernelGraph>
MS_EXCEPTION_IF_NULL(kernel_graph);
std::pair<CNodePtr, CNodePtr> send_recv_nodes(send_node, recv_node);
if (common::AnfAlgo::GetCNodeName(mock_send_node) == kAllReduceOpName) {
kernel_graph->InsertToSendRecvPair(mock_send_node, send_recv_nodes);
kernel_graph->InsertSendRecvPairForParallelOpOutputs(mock_send_node, send_recv_nodes);
}
if (common::AnfAlgo::GetCNodeName(mock_recv_node) == kAllReduceOpName) {
kernel_graph->InsertFromSendRecvPair(mock_recv_node, send_recv_nodes);
kernel_graph->InsertSendRecvPairForParallelOpInputs(mock_recv_node, send_recv_nodes);
}
}
} // namespace gpu

View File

@ -1523,68 +1523,82 @@ void GraphScheduler::LinkControlArrowBySkippedNode(AbstractActor *to_actor, cons
void GraphScheduler::LinkControlArrowBySendRecvNodes(const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
for (auto &from_iter : graph->allreduce_from_send_recv_pairs()) {
auto to_allreduce_node = from_iter.first;
auto from_send_node = from_iter.second.first;
auto from_recv_node = from_iter.second.second;
MS_EXCEPTION_IF_NULL(to_allreduce_node);
MS_EXCEPTION_IF_NULL(from_send_node);
MS_EXCEPTION_IF_NULL(from_recv_node);
MS_LOG(INFO) << "Link control arrow for to_allreduce_node: " << to_allreduce_node->fullname_with_scope();
auto to_allreduce_actor = FetchActor(to_allreduce_node->fullname_with_scope());
auto from_send_actor = FetchActor(from_send_node->fullname_with_scope());
auto from_recv_actor = FetchActor(from_recv_node->fullname_with_scope());
MS_EXCEPTION_IF_NULL(to_allreduce_actor);
MS_EXCEPTION_IF_NULL(from_send_actor);
MS_EXCEPTION_IF_NULL(from_recv_actor);
for (auto &from_iter : graph->send_recv_pairs_for_parallel_op_inputs()) {
auto parallel_node = from_iter.first;
for (auto pair : from_iter.second) {
auto send_node = pair.first;
auto recv_node = pair.second;
MS_EXCEPTION_IF_NULL(parallel_node);
MS_EXCEPTION_IF_NULL(send_node);
MS_EXCEPTION_IF_NULL(recv_node);
MS_LOG(INFO) << "Link control arrow for parallel node input: " << parallel_node->fullname_with_scope();
auto parallel_actor = FetchActor(parallel_node->fullname_with_scope());
auto send_actor = FetchActor(send_node->fullname_with_scope());
auto recv_actor = FetchActor(recv_node->fullname_with_scope());
MS_EXCEPTION_IF_NULL(parallel_actor);
MS_EXCEPTION_IF_NULL(send_actor);
MS_EXCEPTION_IF_NULL(recv_actor);
// inputs of to_allreduce_actor --> from_send_actor
for (auto &input_aid : to_allreduce_actor->input_data_arrow_aids_) {
auto input_actor = dynamic_cast<KernelActor *>(FetchActor(input_aid.Name()));
if (input_actor != nullptr) {
SchedulerHelper::AddControlArrow(input_actor, from_send_actor);
// inputs of to_allreduce_actor --> from_send_actor
for (auto &input_aid : parallel_actor->input_data_arrow_aids_) {
auto input_actor = dynamic_cast<KernelActor *>(FetchActor(input_aid.Name()));
if (input_actor != nullptr) {
SchedulerHelper::AddControlArrow(input_actor, send_actor);
}
}
// from_send_actor --> from_recv_actor
SchedulerHelper::AddControlArrow(send_actor, recv_actor);
// from_recv_actor --> to_allreduce_actor
SchedulerHelper::AddControlArrow(recv_actor, parallel_actor);
}
// from_send_actor --> from_recv_actor
SchedulerHelper::AddControlArrow(from_send_actor, from_recv_actor);
// from_recv_actor --> to_allreduce_actor
SchedulerHelper::AddControlArrow(from_recv_actor, to_allreduce_actor);
}
for (auto &to_iter : graph->allreduce_to_send_recv_pairs()) {
auto from_allreduce_node = to_iter.first;
auto to_send_node = to_iter.second.first;
auto to_recv_node = to_iter.second.second;
MS_EXCEPTION_IF_NULL(from_allreduce_node);
MS_EXCEPTION_IF_NULL(to_send_node);
MS_EXCEPTION_IF_NULL(to_recv_node);
MS_LOG(INFO) << "Link control arrow for from_allreduce_node: " << from_allreduce_node->fullname_with_scope();
auto from_allreduce_actor = FetchActor(from_allreduce_node->fullname_with_scope());
auto to_send_actor = FetchActor(to_send_node->fullname_with_scope());
auto to_recv_actor = dynamic_cast<KernelActor *>(FetchActor(to_recv_node->fullname_with_scope()));
MS_EXCEPTION_IF_NULL(from_allreduce_actor);
MS_EXCEPTION_IF_NULL(to_send_actor);
MS_EXCEPTION_IF_NULL(to_recv_actor);
for (auto &to_iter : graph->send_recv_pairs_for_parallel_op_outputs()) {
auto parallel_node = to_iter.first;
for (auto pair : to_iter.second) {
auto send_node = pair.first;
auto recv_node = pair.second;
MS_EXCEPTION_IF_NULL(parallel_node);
MS_EXCEPTION_IF_NULL(send_node);
MS_EXCEPTION_IF_NULL(recv_node);
MS_LOG(INFO) << "Link control arrow for parallel node output: " << parallel_node->fullname_with_scope();
auto parallel_actor = FetchActor(parallel_node->fullname_with_scope());
auto send_actor = FetchActor(send_node->fullname_with_scope());
auto recv_actor = dynamic_cast<KernelActor *>(FetchActor(recv_node->fullname_with_scope()));
MS_EXCEPTION_IF_NULL(parallel_actor);
MS_EXCEPTION_IF_NULL(send_actor);
MS_EXCEPTION_IF_NULL(recv_actor);
// from_allreduce_actor --> to_send_actor
SchedulerHelper::AddControlArrow(from_allreduce_actor, to_send_actor);
// to_send_actor --> to_recv_actor
SchedulerHelper::AddControlArrow(to_send_actor, to_recv_actor);
// to_recv_actor --> outputs of from_allreduce_actor
for (auto &output_data_arrow : from_allreduce_actor->output_data_arrows_) {
auto output_actor = FetchActor(output_data_arrow->to_op_id_.Name());
if (output_actor != nullptr) {
SchedulerHelper::AddControlArrow(to_recv_actor, output_actor);
// from_allreduce_actor --> to_send_actor
SchedulerHelper::AddControlArrow(parallel_actor, send_actor);
// to_send_actor --> to_recv_actor
SchedulerHelper::AddControlArrow(send_actor, recv_actor);
// to_recv_actor --> outputs of from_allreduce_actor
for (auto &output_data_arrow : parallel_actor->output_data_arrows_) {
auto output_actor = FetchActor(output_data_arrow->to_op_id_.Name());
if (output_actor != nullptr) {
SchedulerHelper::AddControlArrow(recv_actor, output_actor);
}
}
}
// In the scene of allreduce op and computing op parallel multi stream, the input memory of allreduce can be
// reused only when the recv node runs finished, which is expressed by the reference count increased.
for (size_t i = 0; i < common::AnfAlgo::GetInputTensorNum(from_allreduce_node); ++i) {
auto device_tensor = AnfAlgo::GetPrevNodeMutableOutputAddr(from_allreduce_node, i, false);
MS_EXCEPTION_IF_NULL(device_tensor);
UpdateRefCount(device_tensor.get());
(void)to_recv_actor->external_reference_tensors_.emplace_back(device_tensor.get());
// In the scene of allreduce op and computing op parallel multi stream, the input memory of allreduce can be
// reused only when the recv node runs finished, which is expressed by the reference count increased.
for (size_t i = 0; i < common::AnfAlgo::GetInputTensorNum(parallel_node); ++i) {
auto device_tensor = AnfAlgo::GetPrevNodeMutableOutputAddr(parallel_node, i, false);
MS_EXCEPTION_IF_NULL(device_tensor);
UpdateRefCount(device_tensor.get());
(void)recv_actor->external_reference_tensors_.emplace_back(device_tensor.get());
}
auto kernel_mod = AnfAlgo::GetKernelMod(parallel_node);
MS_EXCEPTION_IF_NULL(kernel_mod);
auto workspace_num = kernel_mod->GetWorkspaceSizeList().size();
for (size_t i = 0; i < workspace_num; ++i) {
auto device_tensor = AnfAlgo::GetMutableWorkspaceAddr(parallel_node, i);
MS_EXCEPTION_IF_NULL(device_tensor);
UpdateRefCount(device_tensor.get());
(void)recv_actor->external_reference_tensors_.emplace_back(device_tensor.get());
}
}
}
}