!29769 fix event error

Merge pull request !29769 from TuDouNi/event
This commit is contained in:
i-robot 2022-02-08 10:06:33 +00:00 committed by Gitee
commit 52678993c6
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 60 additions and 1 deletions

View File

@ -877,6 +877,64 @@ void AscendKernelRuntime::GenKernelEvents(const session::KernelGraph &graph) {
graph_kernel_events_map_[graph.graph_id()] = std::move(kernel_events);
}
void AscendKernelRuntime::GenKernelEventsForMindRT(const session::KernelGraph &graph) {
auto &kernels = graph.execution_order();
if (kernels.empty() || graph_kernel_events_map_.find(graph.graph_id()) != graph_kernel_events_map_.end()) {
return;
}
std::vector<size_t> last_stream_nodes;
SetKernelModStream(kernels, &last_stream_nodes);
auto kernel_events = std::pair<std::map<AnfNodePtr, std::vector<std::function<void()>>>,
std::map<AnfNodePtr, std::vector<std::function<void()>>>>();
auto &kernel_pre_run_events = kernel_events.first;
auto &kernel_post_run_events = kernel_events.second;
for (size_t i = 0; i < kernels.size(); ++i) {
auto &kernel = kernels[i];
auto curr_stream_id = AnfAlgo::GetStreamId(kernel);
if (stream_id_map_.find(curr_stream_id) == stream_id_map_.end()) {
MS_LOG(EXCEPTION) << "Stream " << curr_stream_id << "has not been created.";
}
auto wait_stream = stream_id_map_[curr_stream_id];
std::vector<AnfNodePtr> used_kernels;
std::set<AnfNodePtr> visited_kernels;
AnfAlgo::GetAllVisitedCNode(kernel, &used_kernels, &visited_kernels);
bool found_depend = false;
std::set<AnfNodePtr> record_nodes;
// set events for nodes and its input: [input_node_stream, node_stream]
for (auto &visited : used_kernels) {
auto pre_cnode_stream_id = AnfAlgo::GetStreamId(visited);
if (stream_id_map_.find(pre_cnode_stream_id) == stream_id_map_.end()) {
MS_LOG(EXCEPTION) << "Stream " << pre_cnode_stream_id << "has not been created.";
}
if (pre_cnode_stream_id == curr_stream_id) {
found_depend = true;
continue;
}
if (record_nodes.find(visited) == record_nodes.end()) {
found_depend = true;
auto record_stream = stream_id_map_[pre_cnode_stream_id];
auto event = CreateDeviceEvent();
event->set_wait_stream(wait_stream);
event->set_record_stream(record_stream);
kernel_post_run_events[visited].emplace_back([event]() { event->RecordEvent(); });
kernel_pre_run_events[kernel].emplace_back([event]() { event->WaitEvent(); });
}
record_nodes.insert(visited);
}
// for start_node(no inputs), set event [stream_, start_node_stream]
if (!found_depend && wait_stream != stream_) {
auto pre_event = CreateDeviceEvent();
pre_event->set_wait_stream(wait_stream);
pre_event->set_record_stream(stream_);
kernel_pre_run_events[kernel].emplace_back([pre_event]() { pre_event->RecordEvent(); });
kernel_pre_run_events[kernel].emplace_back([pre_event]() { pre_event->WaitEvent(); });
}
}
// find end node of graph by last_stream_nodes, and set event [last_node_stream, stream_]
ProcessBoundaryEvent(kernels, &kernel_post_run_events, last_stream_nodes);
graph_kernel_events_map_[graph.graph_id()] = std::move(kernel_events);
}
std::pair<vector<std::function<void()>>, vector<std::function<void()>>> AscendKernelRuntime::GetKernelEventFuncs(
const CNodePtr &kernel) const {
std::map<AnfNodePtr, std::vector<std::function<void()>>> kernels_pre_event_funcs;

View File

@ -43,6 +43,7 @@ class AscendKernelRuntime : public KernelRuntime {
bool Init() override;
bool LoadData(const session::KernelGraph &graph) override;
bool GenTask(const session::KernelGraph &graph);
void GenKernelEventsForMindRT(const session::KernelGraph &graph);
void GenKernelEvents(const session::KernelGraph &graph) override;
std::pair<vector<std::function<void()>>, vector<std::function<void()>>> GetKernelEventFuncs(
const CNodePtr &kernel) const;

View File

@ -349,7 +349,7 @@ void AscendDeviceContext::UpdateExecOrder(const KernelGraphPtr &graph) const {
void AscendDeviceContext::GenKernelEvents(const NotNull<KernelGraphPtr> &root_graph) const {
MS_LOG(INFO) << "Start GenKernelEvents for graph " << root_graph->graph_id();
MS_EXCEPTION_IF_NULL(runtime_instance_);
runtime_instance_->GenKernelEvents(*root_graph.get());
runtime_instance_->GenKernelEventsForMindRT(*root_graph.get());
MS_LOG(INFO) << "Finish!";
}