forked from mindspore-Ecosystem/mindspore
commit
52678993c6
|
@ -877,6 +877,64 @@ void AscendKernelRuntime::GenKernelEvents(const session::KernelGraph &graph) {
|
|||
graph_kernel_events_map_[graph.graph_id()] = std::move(kernel_events);
|
||||
}
|
||||
|
||||
void AscendKernelRuntime::GenKernelEventsForMindRT(const session::KernelGraph &graph) {
|
||||
auto &kernels = graph.execution_order();
|
||||
if (kernels.empty() || graph_kernel_events_map_.find(graph.graph_id()) != graph_kernel_events_map_.end()) {
|
||||
return;
|
||||
}
|
||||
std::vector<size_t> last_stream_nodes;
|
||||
SetKernelModStream(kernels, &last_stream_nodes);
|
||||
auto kernel_events = std::pair<std::map<AnfNodePtr, std::vector<std::function<void()>>>,
|
||||
std::map<AnfNodePtr, std::vector<std::function<void()>>>>();
|
||||
auto &kernel_pre_run_events = kernel_events.first;
|
||||
auto &kernel_post_run_events = kernel_events.second;
|
||||
for (size_t i = 0; i < kernels.size(); ++i) {
|
||||
auto &kernel = kernels[i];
|
||||
auto curr_stream_id = AnfAlgo::GetStreamId(kernel);
|
||||
if (stream_id_map_.find(curr_stream_id) == stream_id_map_.end()) {
|
||||
MS_LOG(EXCEPTION) << "Stream " << curr_stream_id << "has not been created.";
|
||||
}
|
||||
auto wait_stream = stream_id_map_[curr_stream_id];
|
||||
std::vector<AnfNodePtr> used_kernels;
|
||||
std::set<AnfNodePtr> visited_kernels;
|
||||
AnfAlgo::GetAllVisitedCNode(kernel, &used_kernels, &visited_kernels);
|
||||
bool found_depend = false;
|
||||
std::set<AnfNodePtr> record_nodes;
|
||||
// set events for nodes and its input: [input_node_stream, node_stream]
|
||||
for (auto &visited : used_kernels) {
|
||||
auto pre_cnode_stream_id = AnfAlgo::GetStreamId(visited);
|
||||
if (stream_id_map_.find(pre_cnode_stream_id) == stream_id_map_.end()) {
|
||||
MS_LOG(EXCEPTION) << "Stream " << pre_cnode_stream_id << "has not been created.";
|
||||
}
|
||||
if (pre_cnode_stream_id == curr_stream_id) {
|
||||
found_depend = true;
|
||||
continue;
|
||||
}
|
||||
if (record_nodes.find(visited) == record_nodes.end()) {
|
||||
found_depend = true;
|
||||
auto record_stream = stream_id_map_[pre_cnode_stream_id];
|
||||
auto event = CreateDeviceEvent();
|
||||
event->set_wait_stream(wait_stream);
|
||||
event->set_record_stream(record_stream);
|
||||
kernel_post_run_events[visited].emplace_back([event]() { event->RecordEvent(); });
|
||||
kernel_pre_run_events[kernel].emplace_back([event]() { event->WaitEvent(); });
|
||||
}
|
||||
record_nodes.insert(visited);
|
||||
}
|
||||
// for start_node(no inputs), set event [stream_, start_node_stream]
|
||||
if (!found_depend && wait_stream != stream_) {
|
||||
auto pre_event = CreateDeviceEvent();
|
||||
pre_event->set_wait_stream(wait_stream);
|
||||
pre_event->set_record_stream(stream_);
|
||||
kernel_pre_run_events[kernel].emplace_back([pre_event]() { pre_event->RecordEvent(); });
|
||||
kernel_pre_run_events[kernel].emplace_back([pre_event]() { pre_event->WaitEvent(); });
|
||||
}
|
||||
}
|
||||
// find end node of graph by last_stream_nodes, and set event [last_node_stream, stream_]
|
||||
ProcessBoundaryEvent(kernels, &kernel_post_run_events, last_stream_nodes);
|
||||
graph_kernel_events_map_[graph.graph_id()] = std::move(kernel_events);
|
||||
}
|
||||
|
||||
std::pair<vector<std::function<void()>>, vector<std::function<void()>>> AscendKernelRuntime::GetKernelEventFuncs(
|
||||
const CNodePtr &kernel) const {
|
||||
std::map<AnfNodePtr, std::vector<std::function<void()>>> kernels_pre_event_funcs;
|
||||
|
|
|
@ -43,6 +43,7 @@ class AscendKernelRuntime : public KernelRuntime {
|
|||
bool Init() override;
|
||||
bool LoadData(const session::KernelGraph &graph) override;
|
||||
bool GenTask(const session::KernelGraph &graph);
|
||||
void GenKernelEventsForMindRT(const session::KernelGraph &graph);
|
||||
void GenKernelEvents(const session::KernelGraph &graph) override;
|
||||
std::pair<vector<std::function<void()>>, vector<std::function<void()>>> GetKernelEventFuncs(
|
||||
const CNodePtr &kernel) const;
|
||||
|
|
|
@ -349,7 +349,7 @@ void AscendDeviceContext::UpdateExecOrder(const KernelGraphPtr &graph) const {
|
|||
void AscendDeviceContext::GenKernelEvents(const NotNull<KernelGraphPtr> &root_graph) const {
|
||||
MS_LOG(INFO) << "Start GenKernelEvents for graph " << root_graph->graph_id();
|
||||
MS_EXCEPTION_IF_NULL(runtime_instance_);
|
||||
runtime_instance_->GenKernelEvents(*root_graph.get());
|
||||
runtime_instance_->GenKernelEventsForMindRT(*root_graph.get());
|
||||
MS_LOG(INFO) << "Finish!";
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue