event for kernel by kernel
This commit is contained in:
parent
791cd0e237
commit
8a0595305c
|
@ -106,8 +106,8 @@ class BACKEND_EXPORT KernelGraph : public FuncGraph {
|
|||
input_nodes_ = graph.input_nodes_;
|
||||
pre_graphs_ = graph.pre_graphs_;
|
||||
post_graphs_ = graph.post_graphs_;
|
||||
allreduce_from_send_recv_pairs_ = graph.allreduce_from_send_recv_pairs_;
|
||||
allreduce_to_send_recv_pairs_ = graph.allreduce_to_send_recv_pairs_;
|
||||
send_recv_pairs_for_parallel_op_inputs_ = graph.send_recv_pairs_for_parallel_op_inputs_;
|
||||
send_recv_pairs_for_parallel_op_outputs_ = graph.send_recv_pairs_for_parallel_op_outputs_;
|
||||
size_t pre_graph_finished_count = graph.pre_graph_finished_count_;
|
||||
pre_graph_finished_count_ = pre_graph_finished_count;
|
||||
size_t post_graph_finished_count = graph.post_graph_finished_count_;
|
||||
|
@ -375,18 +375,34 @@ class BACKEND_EXPORT KernelGraph : public FuncGraph {
|
|||
}
|
||||
// end of handle graph dependency
|
||||
|
||||
// The interface of allreduce send/recv pairs map.
|
||||
void InsertFromSendRecvPair(const CNodePtr &allreduce, const std::pair<CNodePtr, CNodePtr> &send_recv_pair) {
|
||||
allreduce_from_send_recv_pairs_[allreduce] = send_recv_pair;
|
||||
// The interface of parallel op send/recv pairs map.
|
||||
void InsertSendRecvPairForParallelOpInputs(const CNodePtr ¶llel_op,
|
||||
const std::pair<CNodePtr, CNodePtr> &send_recv_pair) {
|
||||
auto iter = send_recv_pairs_for_parallel_op_inputs_.find(parallel_op);
|
||||
if (iter == send_recv_pairs_for_parallel_op_inputs_.end()) {
|
||||
send_recv_pairs_for_parallel_op_inputs_[parallel_op] = {send_recv_pair};
|
||||
} else {
|
||||
iter->second.emplace_back(send_recv_pair);
|
||||
}
|
||||
}
|
||||
void InsertToSendRecvPair(const CNodePtr &allreduce, const std::pair<CNodePtr, CNodePtr> &send_recv_pair) {
|
||||
allreduce_to_send_recv_pairs_[allreduce] = send_recv_pair;
|
||||
|
||||
void InsertSendRecvPairForParallelOpOutputs(const CNodePtr ¶llel_op,
|
||||
const std::pair<CNodePtr, CNodePtr> &send_recv_pair) {
|
||||
auto iter = send_recv_pairs_for_parallel_op_outputs_.find(parallel_op);
|
||||
if (iter == send_recv_pairs_for_parallel_op_outputs_.end()) {
|
||||
send_recv_pairs_for_parallel_op_outputs_[parallel_op] = {send_recv_pair};
|
||||
} else {
|
||||
iter->second.emplace_back(send_recv_pair);
|
||||
}
|
||||
}
|
||||
const mindspore::HashMap<CNodePtr, std::pair<CNodePtr, CNodePtr>> &allreduce_from_send_recv_pairs() const {
|
||||
return allreduce_from_send_recv_pairs_;
|
||||
|
||||
const mindspore::HashMap<CNodePtr, std::vector<std::pair<CNodePtr, CNodePtr>>>
|
||||
&send_recv_pairs_for_parallel_op_inputs() const {
|
||||
return send_recv_pairs_for_parallel_op_inputs_;
|
||||
}
|
||||
const mindspore::HashMap<CNodePtr, std::pair<CNodePtr, CNodePtr>> &allreduce_to_send_recv_pairs() const {
|
||||
return allreduce_to_send_recv_pairs_;
|
||||
const mindspore::HashMap<CNodePtr, std::vector<std::pair<CNodePtr, CNodePtr>>>
|
||||
&send_recv_pairs_for_parallel_op_outputs() const {
|
||||
return send_recv_pairs_for_parallel_op_outputs_;
|
||||
}
|
||||
|
||||
uint32_t label_num() const { return label_num_; }
|
||||
|
@ -521,10 +537,11 @@ class BACKEND_EXPORT KernelGraph : public FuncGraph {
|
|||
std::map<session::KernelWithIndex, session::KernelWithIndex, session::KernelWithIndexCmp> nop_node_output_map_;
|
||||
mindspore::HashMap<uint32_t, std::weak_ptr<session::KernelGraph>> pre_graphs_;
|
||||
mindspore::HashMap<uint32_t, std::weak_ptr<session::KernelGraph>> post_graphs_;
|
||||
// The send/recv pairs inserted for allreduce, the key is allreduce kernel, the first of pair is send node, the second
|
||||
// of pair is recv node.
|
||||
mindspore::HashMap<CNodePtr, std::pair<CNodePtr, CNodePtr>> allreduce_from_send_recv_pairs_;
|
||||
mindspore::HashMap<CNodePtr, std::pair<CNodePtr, CNodePtr>> allreduce_to_send_recv_pairs_;
|
||||
|
||||
// key:parallel op ptr, value:vector of <send op receive op > pairs
|
||||
mindspore::HashMap<CNodePtr, std::vector<std::pair<CNodePtr, CNodePtr>>> send_recv_pairs_for_parallel_op_inputs_;
|
||||
// key:parallel op ptr, value:vector of <send op receive op > pairs
|
||||
mindspore::HashMap<CNodePtr, std::vector<std::pair<CNodePtr, CNodePtr>>> send_recv_pairs_for_parallel_op_outputs_;
|
||||
std::atomic<size_t> pre_graph_finished_count_{0};
|
||||
std::atomic<size_t> post_graph_finished_count_{0};
|
||||
bool first_step_{true};
|
||||
|
|
|
@ -743,6 +743,29 @@ bool AscendKernelRuntime::Run(const session::KernelGraph &graph, bool is_task_si
|
|||
return ret;
|
||||
}
|
||||
|
||||
void AscendKernelRuntime::SetKernelModRtStream(const NotNull<KernelGraphPtr> &graph_ptr) {
|
||||
const auto &kernels = graph_ptr->execution_order();
|
||||
for (size_t i = 0; i < kernels.size(); ++i) {
|
||||
auto &node = kernels[i];
|
||||
auto kernel_mod = AnfAlgo::GetKernelMod(node);
|
||||
auto ascend_kernel_mod = dynamic_cast<kernel::AscendKernelMod *>(kernel_mod);
|
||||
MS_EXCEPTION_IF_NULL(ascend_kernel_mod);
|
||||
auto stream_id = AnfAlgo::GetStreamId(kernels[i]);
|
||||
auto iter = stream_id_map_.find(stream_id);
|
||||
if (iter == stream_id_map_.end()) {
|
||||
void *stream = nullptr;
|
||||
auto ret = rtStreamCreate(&stream, 0);
|
||||
if (ret != RT_ERROR_NONE) {
|
||||
MS_LOG(EXCEPTION) << "create communication stream failed, ret:" << ret;
|
||||
}
|
||||
stream_id_map_[stream_id] = stream;
|
||||
ascend_kernel_mod->set_stream(stream);
|
||||
} else {
|
||||
ascend_kernel_mod->set_stream(iter->second);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void AscendKernelRuntime::SetKernelModStream(const std::vector<CNodePtr> &kernels,
|
||||
std::vector<size_t> *last_stream_nodes) {
|
||||
std::map<void *, size_t> last_kernel;
|
||||
|
@ -872,91 +895,6 @@ void AscendKernelRuntime::GenKernelEvents(const session::KernelGraph &graph) {
|
|||
graph_kernel_events_map_[graph.graph_id()] = std::move(kernel_events);
|
||||
}
|
||||
|
||||
void AscendKernelRuntime::GenKernelEventsForMindRT(const session::KernelGraph &graph) {
|
||||
auto &kernels = graph.execution_order();
|
||||
if (kernels.empty() || graph_kernel_events_map_.find(graph.graph_id()) != graph_kernel_events_map_.end()) {
|
||||
return;
|
||||
}
|
||||
std::vector<size_t> last_stream_nodes;
|
||||
SetKernelModStream(kernels, &last_stream_nodes);
|
||||
auto kernel_events = std::pair<std::map<AnfNodePtr, std::vector<std::function<void()>>>,
|
||||
std::map<AnfNodePtr, std::vector<std::function<void()>>>>();
|
||||
auto &kernel_pre_run_events = kernel_events.first;
|
||||
auto &kernel_post_run_events = kernel_events.second;
|
||||
for (size_t i = 0; i < kernels.size(); ++i) {
|
||||
auto &kernel = kernels[i];
|
||||
auto curr_stream_id = AnfAlgo::GetStreamId(kernel);
|
||||
if (stream_id_map_.find(curr_stream_id) == stream_id_map_.end()) {
|
||||
MS_LOG(EXCEPTION) << "Stream " << curr_stream_id << "has not been created.";
|
||||
}
|
||||
auto wait_stream = stream_id_map_[curr_stream_id];
|
||||
std::vector<AnfNodePtr> used_kernels;
|
||||
std::set<AnfNodePtr> visited_kernels;
|
||||
common::AnfAlgo::GetAllVisitedCNode(kernel, &used_kernels, &visited_kernels);
|
||||
bool found_depend = false;
|
||||
std::set<AnfNodePtr> record_nodes;
|
||||
// set events for nodes and its input: [input_node_stream, node_stream]
|
||||
for (auto &visited : used_kernels) {
|
||||
auto pre_cnode_stream_id = AnfAlgo::GetStreamId(visited);
|
||||
if (stream_id_map_.find(pre_cnode_stream_id) == stream_id_map_.end()) {
|
||||
MS_LOG(EXCEPTION) << "Stream " << pre_cnode_stream_id << "has not been created.";
|
||||
}
|
||||
if (pre_cnode_stream_id == curr_stream_id) {
|
||||
found_depend = true;
|
||||
continue;
|
||||
}
|
||||
if (record_nodes.find(visited) == record_nodes.end()) {
|
||||
found_depend = true;
|
||||
auto record_stream = stream_id_map_[pre_cnode_stream_id];
|
||||
auto event = CreateDeviceEvent();
|
||||
event->set_wait_stream(wait_stream);
|
||||
event->set_record_stream(record_stream);
|
||||
kernel_post_run_events[visited].emplace_back([event]() { event->RecordEvent(); });
|
||||
kernel_pre_run_events[kernel].emplace_back([event]() { event->WaitEvent(); });
|
||||
}
|
||||
record_nodes.insert(visited);
|
||||
}
|
||||
// for start_node(no inputs), set event [stream_, start_node_stream]
|
||||
if (!found_depend && wait_stream != stream_) {
|
||||
auto pre_event = CreateDeviceEvent();
|
||||
pre_event->set_wait_stream(wait_stream);
|
||||
pre_event->set_record_stream(stream_);
|
||||
kernel_pre_run_events[kernel].emplace_back([pre_event]() { pre_event->RecordEvent(); });
|
||||
kernel_pre_run_events[kernel].emplace_back([pre_event]() { pre_event->WaitEvent(); });
|
||||
}
|
||||
}
|
||||
// find end node of graph by last_stream_nodes, and set event [last_node_stream, stream_]
|
||||
ProcessBoundaryEvent(kernels, &kernel_post_run_events, last_stream_nodes);
|
||||
graph_kernel_events_map_[graph.graph_id()] = std::move(kernel_events);
|
||||
}
|
||||
|
||||
std::pair<vector<std::function<void()>>, vector<std::function<void()>>> AscendKernelRuntime::GetKernelEventFuncs(
|
||||
const CNodePtr &kernel) const {
|
||||
std::map<AnfNodePtr, std::vector<std::function<void()>>> kernels_pre_event_funcs;
|
||||
std::map<AnfNodePtr, std::vector<std::function<void()>>> kernels_post_event_funcs;
|
||||
std::vector<std::function<void()>> kernel_pre_event_funcs;
|
||||
std::vector<std::function<void()>> kernel_post_event_funcs;
|
||||
|
||||
auto graph_id = AnfAlgo::GetGraphId(kernel.get());
|
||||
auto events_iter = graph_kernel_events_map_.find(graph_id);
|
||||
if (events_iter != graph_kernel_events_map_.end()) {
|
||||
kernels_pre_event_funcs = events_iter->second.first;
|
||||
kernels_post_event_funcs = events_iter->second.second;
|
||||
}
|
||||
|
||||
auto pre_event_funcs_iter = kernels_pre_event_funcs.find(kernel);
|
||||
if (pre_event_funcs_iter != kernels_pre_event_funcs.end()) {
|
||||
kernel_pre_event_funcs = pre_event_funcs_iter->second;
|
||||
}
|
||||
|
||||
auto post_event_funcs_iter = kernels_post_event_funcs.find(kernel);
|
||||
if (post_event_funcs_iter != kernels_post_event_funcs.end()) {
|
||||
kernel_post_event_funcs = post_event_funcs_iter->second;
|
||||
}
|
||||
|
||||
return std::make_pair(kernel_pre_event_funcs, kernel_post_event_funcs);
|
||||
}
|
||||
|
||||
void AscendKernelRuntime::ProcessBoundaryEvent(
|
||||
const std::vector<CNodePtr> &kernels, std::map<AnfNodePtr, std::vector<std::function<void()>>> *kernel_run_events,
|
||||
const std::vector<size_t> &last_stream_nodes) {
|
||||
|
|
|
@ -44,11 +44,9 @@ class AscendKernelRuntime : public KernelRuntime {
|
|||
bool Init() override;
|
||||
bool LoadData(const session::KernelGraph &graph) override;
|
||||
bool GenTask(const session::KernelGraph &graph);
|
||||
void GenKernelEventsForMindRT(const session::KernelGraph &graph);
|
||||
void GenKernelEvents(const session::KernelGraph &graph) override;
|
||||
std::pair<vector<std::function<void()>>, vector<std::function<void()>>> GetKernelEventFuncs(
|
||||
const CNodePtr &kernel) const;
|
||||
void SetKernelModStream(const std::vector<CNodePtr> &kernels, std::vector<size_t> *last_stream_nodes);
|
||||
void SetKernelModRtStream(const NotNull<KernelGraphPtr> &graph_ptr);
|
||||
void ProcessBoundaryEvent(const std::vector<CNodePtr> &kernels,
|
||||
std::map<AnfNodePtr, std::vector<std::function<void()>>> *kernel_run_events,
|
||||
const std::vector<size_t> &last_stream_nodes);
|
||||
|
|
|
@ -217,28 +217,11 @@ void AscendStreamAssign::AssignStreamForNonTaskSink(const std::vector<CNodePtr>
|
|||
if (kernels.empty()) {
|
||||
return;
|
||||
}
|
||||
if (stream_groups_.empty()) {
|
||||
stream_groups_.emplace_back(std::vector<uint32_t>{kDefaultStreamIndex});
|
||||
stream_groups_.emplace_back(std::vector<uint32_t>{kIndependentStreamIndex});
|
||||
stream_groups_.emplace_back(std::vector<uint32_t>{kWorldGroupStreamIndex});
|
||||
}
|
||||
group_stream_id_map_[kHcclWorldGroup] = kWorldGroupStreamIndex;
|
||||
for (size_t i = 0; i < kernels.size(); ++i) {
|
||||
auto &node = kernels[i];
|
||||
if (common::AnfAlgo::IsCommunicationOp(node)) {
|
||||
auto group = common::AnfAlgo::GetNodeAttr<std::string>(node, kAttrGroup);
|
||||
auto iter = group_stream_id_map_.find(group);
|
||||
if (iter == group_stream_id_map_.end()) {
|
||||
auto id = SizeToUint(group_stream_id_map_.size()) + kWorldGroupStreamIndex;
|
||||
group_stream_id_map_[group] = id;
|
||||
AnfAlgo::SetStreamId(id, node.get());
|
||||
stream_groups_.emplace_back(std::vector<uint32_t>{id});
|
||||
} else {
|
||||
auto id = iter->second;
|
||||
AnfAlgo::SetStreamId(id, node.get());
|
||||
}
|
||||
} else if (AnfAlgo::IsIndependentNode(node)) {
|
||||
AnfAlgo::SetStreamId(kIndependentStreamIndex, node.get());
|
||||
AnfAlgo::SetStreamId(kWorldGroupStreamIndex, node.get());
|
||||
} else {
|
||||
AnfAlgo::SetStreamId(kDefaultStreamIndex, node.get());
|
||||
}
|
||||
|
@ -251,6 +234,234 @@ void AscendStreamAssign::AssignStreamForNonTaskSink(const std::vector<CNodePtr>
|
|||
}
|
||||
}
|
||||
|
||||
void GenKernelIoExecInfoMap(const NotNull<KernelGraphPtr> &kernel_graph,
|
||||
mindspore::HashMap<CNodePtr, NodeIoExecInfoPtr> *kernel_io_exec_info_map) {
|
||||
auto &exec_kernels = kernel_graph->execution_order();
|
||||
for (size_t i = 0; i < exec_kernels.size(); ++i) {
|
||||
auto &process_kernel = exec_kernels[i];
|
||||
MS_EXCEPTION_IF_NULL(process_kernel);
|
||||
auto process_exec_info = std::make_shared<NodeExecInfo>();
|
||||
MS_EXCEPTION_IF_NULL(process_exec_info);
|
||||
process_exec_info->node = process_kernel;
|
||||
process_exec_info->stream_id = AnfAlgo::GetStreamId(process_kernel);
|
||||
process_exec_info->execution_order_index = i;
|
||||
auto process_io_exec_info = std::make_shared<NodeIoExecInfo>();
|
||||
MS_EXCEPTION_IF_NULL(process_io_exec_info);
|
||||
process_io_exec_info->node_exec_info = process_exec_info;
|
||||
process_io_exec_info->inputs = {};
|
||||
process_io_exec_info->outputs = {};
|
||||
(*kernel_io_exec_info_map)[process_kernel] = process_io_exec_info;
|
||||
}
|
||||
|
||||
for (auto &process_kernel : exec_kernels) {
|
||||
MS_EXCEPTION_IF_NULL(process_kernel);
|
||||
auto process_iter = kernel_io_exec_info_map->find(process_kernel);
|
||||
if (process_iter == kernel_io_exec_info_map->end()) {
|
||||
MS_LOG(ERROR) << "Can't get kernel io execution info for " << process_kernel->fullname_with_scope();
|
||||
continue;
|
||||
}
|
||||
auto process_io_exec_info = process_iter->second;
|
||||
auto process_exec_info = process_iter->second->node_exec_info;
|
||||
|
||||
auto inputs = process_kernel->inputs();
|
||||
for (size_t i = 1; i < inputs.size(); i++) {
|
||||
auto input_node = common::AnfAlgo::VisitKernelWithReturnType(inputs[i], 0).first;
|
||||
MS_EXCEPTION_IF_NULL(input_node);
|
||||
if (AnfUtils::IsRealCNodeKernel(input_node)) {
|
||||
auto input_kernel = input_node->cast<CNodePtr>();
|
||||
auto iter = kernel_io_exec_info_map->find(input_kernel);
|
||||
if (iter == kernel_io_exec_info_map->end()) {
|
||||
MS_LOG(ERROR) << "Can't get kernel io execution info for " << process_kernel->fullname_with_scope()
|
||||
<< "'s input node " << input_kernel->fullname_with_scope();
|
||||
continue;
|
||||
}
|
||||
auto input_io_exec_info = iter->second;
|
||||
auto input_exec_info = iter->second->node_exec_info;
|
||||
process_io_exec_info->inputs.push_back(input_exec_info);
|
||||
input_io_exec_info->outputs.push_back(process_exec_info);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void AscendStreamAssign::InsertEventsForInputs(const NotNull<KernelGraphPtr> &kernel_graph, const CNodePtr &kernel,
|
||||
const NodeIoExecInfoPtr &io_exec_info,
|
||||
mindspore::HashMap<AnfNodePtr, std::vector<CNodePtr>> *kernel_send,
|
||||
mindspore::HashMap<AnfNodePtr, std::vector<CNodePtr>> *kernel_recv) {
|
||||
auto process_stream_id = AnfAlgo::GetStreamId(kernel);
|
||||
auto input_exec_info_list = io_exec_info->inputs;
|
||||
mindspore::HashMap<uint32_t, NodeExecInfoPtr> stream_max_exec_node_map;
|
||||
|
||||
for (auto &input : input_exec_info_list) {
|
||||
MS_EXCEPTION_IF_NULL(input);
|
||||
auto input_stream_id = input->stream_id;
|
||||
auto iter = stream_max_exec_node_map.find(input_stream_id);
|
||||
if (iter == stream_max_exec_node_map.end()) {
|
||||
stream_max_exec_node_map[input_stream_id] = input;
|
||||
} else {
|
||||
MS_EXCEPTION_IF_NULL(iter->second);
|
||||
if (input->execution_order_index > iter->second->execution_order_index) {
|
||||
iter->second = input;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (auto input_exec : stream_max_exec_node_map) {
|
||||
MS_EXCEPTION_IF_NULL(input_exec.second);
|
||||
if (input_exec.second->stream_id == process_stream_id) {
|
||||
continue;
|
||||
}
|
||||
InsertEvents(kernel_graph, kernel, input_exec.second->node, kernel_send, kernel_recv, kernel);
|
||||
}
|
||||
}
|
||||
|
||||
void AscendStreamAssign::InsertEventsForOutputs(const NotNull<KernelGraphPtr> &kernel_graph, const CNodePtr &kernel,
|
||||
const NodeIoExecInfoPtr &io_exec_info,
|
||||
mindspore::HashMap<AnfNodePtr, std::vector<CNodePtr>> *kernel_send,
|
||||
mindspore::HashMap<AnfNodePtr, std::vector<CNodePtr>> *kernel_recv) {
|
||||
auto process_stream_id = AnfAlgo::GetStreamId(kernel);
|
||||
auto output_exec_info_list = io_exec_info->outputs;
|
||||
mindspore::HashMap<uint32_t, NodeExecInfoPtr> stream_min_exec_node_map;
|
||||
for (auto &output : output_exec_info_list) {
|
||||
MS_EXCEPTION_IF_NULL(output);
|
||||
auto output_stream_id = output->stream_id;
|
||||
auto iter = stream_min_exec_node_map.find(output_stream_id);
|
||||
if (iter == stream_min_exec_node_map.end()) {
|
||||
stream_min_exec_node_map[output_stream_id] = output;
|
||||
} else {
|
||||
MS_EXCEPTION_IF_NULL(iter->second);
|
||||
if (output->execution_order_index < iter->second->execution_order_index) {
|
||||
iter->second = output;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (auto output_exec : stream_min_exec_node_map) {
|
||||
MS_EXCEPTION_IF_NULL(output_exec.second);
|
||||
if (output_exec.second->stream_id == process_stream_id) {
|
||||
continue;
|
||||
}
|
||||
InsertEvents(kernel_graph, kernel, kernel, kernel_send, kernel_recv, output_exec.second->node);
|
||||
}
|
||||
|
||||
// parallel op has output tensor, and it didn't connect to other kernel, it's output is graph output, sync it.
|
||||
if (output_exec_info_list.empty() && (common::AnfAlgo::GetOutputTensorNum(kernel) != 0)) {
|
||||
InsertEvents(kernel_graph, kernel, kernel, kernel_send, kernel_recv, kernel_graph->output());
|
||||
}
|
||||
}
|
||||
|
||||
void AscendStreamAssign::InsertEvents(const NotNull<KernelGraphPtr> &kernel_graph, const CNodePtr ¶llel_cnode,
|
||||
const AnfNodePtr &node_before_send,
|
||||
mindspore::HashMap<AnfNodePtr, std::vector<CNodePtr>> *kernel_send,
|
||||
mindspore::HashMap<AnfNodePtr, std::vector<CNodePtr>> *kernel_recv,
|
||||
const AnfNodePtr &node_after_recv) {
|
||||
AscendStreamMng &resource_manager = AscendStreamMng::GetInstance();
|
||||
uint32_t event_id = resource_manager.ApplyNewEvent();
|
||||
auto event = resource_manager.ApplyRtEvent();
|
||||
auto send_cnode = CreateSendApplyKernel(kernel_graph, event_id, AnfAlgo::GetStreamId(node_before_send));
|
||||
common::AnfAlgo::SetNodeAttr(kAttrRecordEvent, MakeValue(reinterpret_cast<uintptr_t>(event)), send_cnode);
|
||||
auto send_iter = kernel_send->find(node_before_send);
|
||||
if (send_iter == kernel_send->end()) {
|
||||
(*kernel_send)[node_before_send] = {send_cnode};
|
||||
} else {
|
||||
send_iter->second.push_back(send_cnode);
|
||||
}
|
||||
|
||||
CNodePtr recv_cnode = CreateRecvApplyKernel(kernel_graph, event_id, AnfAlgo::GetStreamId(node_after_recv));
|
||||
common::AnfAlgo::SetNodeAttr(kAttrWaitEvent, MakeValue(reinterpret_cast<uintptr_t>(event)), recv_cnode);
|
||||
auto process_iter = kernel_recv->find(node_after_recv);
|
||||
if (process_iter == kernel_recv->end()) {
|
||||
(*kernel_recv)[node_after_recv] = {recv_cnode};
|
||||
} else {
|
||||
process_iter->second.push_back(recv_cnode);
|
||||
}
|
||||
|
||||
if (parallel_cnode == node_before_send) {
|
||||
kernel_graph->InsertSendRecvPairForParallelOpOutputs(parallel_cnode, std::make_pair(send_cnode, recv_cnode));
|
||||
MS_LOG(INFO) << "Generate send/recv for parallel op " << parallel_cnode->fullname_with_scope() << "'s output."
|
||||
<< "Send node " << send_cnode->fullname_with_scope() << " after "
|
||||
<< node_before_send->fullname_with_scope() << ", recv node " << recv_cnode->fullname_with_scope()
|
||||
<< " before " << node_after_recv->fullname_with_scope();
|
||||
} else {
|
||||
kernel_graph->InsertSendRecvPairForParallelOpInputs(parallel_cnode, std::make_pair(send_cnode, recv_cnode));
|
||||
MS_LOG(INFO) << "Generate send/recv for parallel op " << parallel_cnode->fullname_with_scope() << "'s input."
|
||||
<< "Send node " << send_cnode->fullname_with_scope() << " after "
|
||||
<< node_before_send->fullname_with_scope() << ", recv node " << recv_cnode->fullname_with_scope()
|
||||
<< " before " << node_after_recv->fullname_with_scope();
|
||||
}
|
||||
}
|
||||
|
||||
void AscendStreamAssign::GenEventsForParallelOp(const NotNull<KernelGraphPtr> &kernel_graph,
|
||||
mindspore::HashMap<AnfNodePtr, std::vector<CNodePtr>> *kernel_send,
|
||||
mindspore::HashMap<AnfNodePtr, std::vector<CNodePtr>> *kernel_recv) {
|
||||
MS_LOG(DEBUG) << "Start GenEventsForParallelOp...";
|
||||
auto exec_kernels = kernel_graph->execution_order();
|
||||
mindspore::HashMap<CNodePtr, NodeIoExecInfoPtr> kernel_io_exec_info_map;
|
||||
GenKernelIoExecInfoMap(kernel_graph, &kernel_io_exec_info_map);
|
||||
for (auto &process_kernel : exec_kernels) {
|
||||
MS_EXCEPTION_IF_NULL(process_kernel);
|
||||
auto process_stream_id = AnfAlgo::GetStreamId(process_kernel);
|
||||
if (process_stream_id == kDefaultStreamIndex) {
|
||||
continue;
|
||||
}
|
||||
MS_LOG(DEBUG) << "Start GenEvents For ParallelOp " << process_kernel->fullname_with_scope();
|
||||
auto process_iter = kernel_io_exec_info_map.find(process_kernel);
|
||||
if (process_iter == kernel_io_exec_info_map.end()) {
|
||||
MS_LOG(ERROR) << "Can't get node io execution info for " << process_kernel->fullname_with_scope();
|
||||
continue;
|
||||
}
|
||||
auto process_io_exec_info = process_iter->second;
|
||||
InsertEventsForInputs(kernel_graph, process_kernel, process_io_exec_info, kernel_send, kernel_recv);
|
||||
InsertEventsForOutputs(kernel_graph, process_kernel, process_io_exec_info, kernel_send, kernel_recv);
|
||||
}
|
||||
MS_LOG(DEBUG) << "Finish GenEventsForParallelOp.";
|
||||
}
|
||||
|
||||
void AscendStreamAssign::UpdateEventsToExecutionOrder(
|
||||
const NotNull<KernelGraphPtr> &kernel_graph,
|
||||
const mindspore::HashMap<AnfNodePtr, std::vector<CNodePtr>> &send_after_node,
|
||||
const mindspore::HashMap<AnfNodePtr, std::vector<CNodePtr>> &recv_before_node) const {
|
||||
MS_LOG(DEBUG) << "Start UpdateEventsToExecutionOrder...";
|
||||
auto exec_kernels = kernel_graph->execution_order();
|
||||
std::vector<CNodePtr> new_exec_orders;
|
||||
for (auto &kernel : exec_kernels) {
|
||||
auto before_iter = recv_before_node.find(kernel);
|
||||
if (before_iter != recv_before_node.end()) {
|
||||
for (auto &recv : before_iter->second) {
|
||||
new_exec_orders.push_back(recv);
|
||||
}
|
||||
}
|
||||
|
||||
new_exec_orders.push_back(kernel);
|
||||
|
||||
auto after_iter = send_after_node.find(kernel);
|
||||
if (after_iter != send_after_node.end()) {
|
||||
for (auto send : after_iter->second) {
|
||||
new_exec_orders.push_back(send);
|
||||
}
|
||||
}
|
||||
}
|
||||
auto graph_output = kernel_graph->output();
|
||||
auto graph_output_iter = recv_before_node.find(graph_output);
|
||||
if (graph_output_iter != recv_before_node.end()) {
|
||||
for (auto &recv : graph_output_iter->second) {
|
||||
new_exec_orders.push_back(recv);
|
||||
}
|
||||
}
|
||||
|
||||
kernel_graph->set_execution_order(new_exec_orders);
|
||||
MS_LOG(DEBUG) << "Finish UpdateEventsToExecutionOrder.";
|
||||
}
|
||||
|
||||
void AscendStreamAssign::InsertEventForNonTaskSink(const NotNull<KernelGraphPtr> &kernel_graph) {
|
||||
mindspore::HashMap<AnfNodePtr, std::vector<CNodePtr>> kernel_send;
|
||||
mindspore::HashMap<AnfNodePtr, std::vector<CNodePtr>> kernel_recv;
|
||||
AnfAlgo::SetStreamId(kDefaultStreamIndex, kernel_graph->output().get());
|
||||
GenEventsForParallelOp(kernel_graph, &kernel_send, &kernel_recv);
|
||||
UpdateEventsToExecutionOrder(kernel_graph, kernel_send, kernel_recv);
|
||||
InsertEventForMicroBatchIndependent(kernel_graph);
|
||||
}
|
||||
|
||||
void AscendStreamAssign::AssignStream(const NotNull<KernelGraphPtr> &graph_ptr) {
|
||||
if (graph_ptr->is_dynamic_shape()) {
|
||||
MS_LOG(WARNING) << "Dynamic shape do not need to assign stream.";
|
||||
|
@ -263,6 +474,7 @@ void AscendStreamAssign::AssignStream(const NotNull<KernelGraphPtr> &graph_ptr)
|
|||
if (!IsTaskSink()) {
|
||||
auto kernels = graph_ptr->execution_order();
|
||||
AssignStreamForNonTaskSink(kernels);
|
||||
InsertEventForNonTaskSink(graph_ptr);
|
||||
MS_LOG(INFO) << "After finish stream assign";
|
||||
graph_ptr->PrintGraphExecuteOrder();
|
||||
PROF_END(assign_stream);
|
||||
|
@ -2535,6 +2747,7 @@ void AscendStreamAssign::InsertEventForMicroBatchIndependent(const NotNull<Kerne
|
|||
uint32_t cur_event_id = resource_manager.ApplyNewEvent();
|
||||
CNodePtr send_cnode = CreateSendApplyKernel(graph_ptr, cur_event_id, AnfAlgo::GetStreamId((cnode)));
|
||||
CNodePtr recv_cnode = CreateRecvApplyKernel(graph_ptr, cur_event_id, AnfAlgo::GetStreamId(next_gen_mask));
|
||||
graph_ptr->InsertSendRecvPairForParallelOpInputs(next_gen_mask, std::make_pair(send_cnode, recv_cnode));
|
||||
node_send_map[cnode] = send_cnode;
|
||||
node_recv_map[next_gen_mask] = recv_cnode;
|
||||
}
|
||||
|
|
|
@ -47,6 +47,21 @@ using GroupGraphMap = std::map<std::string, std::map<uint32_t, std::vector<CNode
|
|||
const uint32_t kInvalidStreamId = UINT32_MAX;
|
||||
const uint32_t kInvalidEventId = UINT32_MAX;
|
||||
enum StreamActiveKind { kInvalid = 0, kHead, kMiddle, kTail };
|
||||
|
||||
struct NodeExecInfo {
|
||||
CNodePtr node;
|
||||
uint32_t stream_id;
|
||||
size_t execution_order_index;
|
||||
};
|
||||
using NodeExecInfoPtr = std::shared_ptr<NodeExecInfo>;
|
||||
|
||||
struct NodeIoExecInfo {
|
||||
NodeExecInfoPtr node_exec_info;
|
||||
std::vector<NodeExecInfoPtr> inputs;
|
||||
std::vector<NodeExecInfoPtr> outputs;
|
||||
};
|
||||
using NodeIoExecInfoPtr = std::shared_ptr<NodeIoExecInfo>;
|
||||
|
||||
class AscendStreamAssign {
|
||||
public:
|
||||
static AscendStreamAssign &GetInstance() {
|
||||
|
@ -191,6 +206,27 @@ class AscendStreamAssign {
|
|||
|
||||
uint32_t max_stream_count_ = 0;
|
||||
uint32_t max_task_count_ = 0;
|
||||
|
||||
// insert event for kernel by kernel
|
||||
void InsertEventForNonTaskSink(const NotNull<KernelGraphPtr> &kernel_graph);
|
||||
void GenEventsForParallelOp(const NotNull<KernelGraphPtr> &kernel_graph,
|
||||
HashMap<AnfNodePtr, vector<CNodePtr>> *kernel_send,
|
||||
HashMap<AnfNodePtr, vector<CNodePtr>> *kernel_recv);
|
||||
void UpdateEventsToExecutionOrder(const NotNull<KernelGraphPtr> &kernel_graph,
|
||||
const HashMap<AnfNodePtr, vector<CNodePtr>> &recv_before_node,
|
||||
const HashMap<AnfNodePtr, vector<CNodePtr>> &send_after_node) const;
|
||||
|
||||
void InsertEventsForInputs(const NotNull<KernelGraphPtr> &kernel_graph, const CNodePtr &kernel,
|
||||
const NodeIoExecInfoPtr &io_exec_info, HashMap<AnfNodePtr, vector<CNodePtr>> *kernel_send,
|
||||
HashMap<AnfNodePtr, vector<CNodePtr>> *kernel_recv);
|
||||
|
||||
void InsertEventsForOutputs(const NotNull<KernelGraphPtr> &kernel_graph, const CNodePtr &kernel,
|
||||
const NodeIoExecInfoPtr &io_exec_info, HashMap<AnfNodePtr, vector<CNodePtr>> *kernel_send,
|
||||
HashMap<AnfNodePtr, vector<CNodePtr>> *kernel_recv);
|
||||
|
||||
void InsertEvents(const NotNull<KernelGraphPtr> &kernel_graph, const CNodePtr ¶llel_cnode,
|
||||
const AnfNodePtr &node_before_send, HashMap<mindspore::AnfNodePtr, vector<CNodePtr>> *kernel_send,
|
||||
HashMap<AnfNodePtr, vector<CNodePtr>> *kernel_recv, const AnfNodePtr &node_after_recv);
|
||||
};
|
||||
} // namespace ascend
|
||||
} // namespace device
|
||||
|
|
|
@ -17,6 +17,10 @@
|
|||
#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_STREAM_MANAGER_H_
|
||||
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_STREAM_MANAGER_H_
|
||||
|
||||
#include <memory>
|
||||
#include "utils/hash_map.h"
|
||||
#include "runtime/event.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace device {
|
||||
namespace ascend {
|
||||
|
@ -36,6 +40,16 @@ class AscendStreamMng {
|
|||
|
||||
uint32_t ApplyNewEvent() { return cur_event_num_++; }
|
||||
|
||||
rtEvent_t ApplyRtEvent() {
|
||||
auto rt_resource = std::make_shared<rtEvent_t>();
|
||||
auto ret = rtEventCreate(rt_resource.get());
|
||||
if (ret != RT_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "rtEventCreate failed, ret:" << ret;
|
||||
*rt_resource = nullptr;
|
||||
}
|
||||
return *rt_resource;
|
||||
}
|
||||
|
||||
void DeleteEvent() {
|
||||
if (!cur_event_num_) {
|
||||
MS_LOG(WARNING) << "total event num is 0, no event to delete";
|
||||
|
|
|
@ -367,13 +367,6 @@ void AscendDeviceContext::UpdateExecOrder(const KernelGraphPtr &graph) const {
|
|||
node_atomics_.clear();
|
||||
}
|
||||
|
||||
void AscendDeviceContext::GenKernelEvents(const NotNull<KernelGraphPtr> &root_graph) const {
|
||||
MS_LOG(INFO) << "Start GenKernelEvents for graph " << root_graph->graph_id();
|
||||
MS_EXCEPTION_IF_NULL(runtime_instance_);
|
||||
runtime_instance_->GenKernelEventsForMindRT(*root_graph.get());
|
||||
MS_LOG(INFO) << "Finish!";
|
||||
}
|
||||
|
||||
void AscendDeviceContext::SetAtomicCleanToNodes(const KernelGraphPtr &graph,
|
||||
const std::map<CNodePtr, std::vector<CNodePtr>> &atomics_node) const {
|
||||
// don't clear node_atomics_ in the end, since atomic_clean_nodes_ in kernel.h is weakptr
|
||||
|
@ -422,6 +415,9 @@ void AscendDeviceContext::PreprocessBeforeRunGraph(const KernelGraphPtr &graph)
|
|||
} else {
|
||||
PreprocessBeforeRunSingleOpGraph(graph);
|
||||
AscendStreamAssign::GetInstance().AssignStream(NOT_NULL(graph));
|
||||
CreateKernel(graph->execution_order());
|
||||
MS_EXCEPTION_IF_NULL(runtime_instance_);
|
||||
runtime_instance_->SetKernelModRtStream(NOT_NULL(graph));
|
||||
}
|
||||
} catch (const std::exception &e) {
|
||||
ReportErrorMessage();
|
||||
|
@ -825,7 +821,7 @@ void *AscendDeviceContext::GetKernelStream(const CNodePtr &node) const {
|
|||
auto stream = kernel_mod->stream();
|
||||
if (stream == nullptr) {
|
||||
stream = compute_stream_;
|
||||
MS_LOG(INFO) << "Assign default compute stream for node " << node->fullname_with_scope();
|
||||
MS_LOG(ERROR) << "Assign default compute stream for node " << node->fullname_with_scope();
|
||||
}
|
||||
return stream;
|
||||
}
|
||||
|
@ -945,7 +941,7 @@ bool AscendDeviceContext::LaunchAtomicClean(const CNodePtr &node, const std::vec
|
|||
// Launch Atomic Node
|
||||
auto kernel_mod = AnfAlgo::GetKernelMod(atomic_node);
|
||||
MS_EXCEPTION_IF_NULL(kernel_mod);
|
||||
return kernel_mod->Launch(atomic_inputs, {}, {}, compute_stream_);
|
||||
return kernel_mod->Launch(atomic_inputs, {}, {}, GetKernelStream(node));
|
||||
}
|
||||
|
||||
void AscendDeviceContext::InsertEventBeforeRunTask(const KernelGraphPtr &graph) const {
|
||||
|
|
|
@ -140,7 +140,6 @@ class AscendDeviceContext : public DeviceContext {
|
|||
static bool IsGraphMode();
|
||||
bool PySyncRuning() const;
|
||||
bool MemoryCopyAsync(const CNodePtr &node, const vector<AddressPtr> &inputs, const vector<AddressPtr> &outputs) const;
|
||||
void GenKernelEvents(const NotNull<KernelGraphPtr> &root_graph) const;
|
||||
void InsertEventBeforeRunTask(const KernelGraphPtr &graph) const;
|
||||
void SetAtomicCleanToNodes(const KernelGraphPtr &graph,
|
||||
const std::map<CNodePtr, std::vector<CNodePtr>> &atomics_node) const;
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#include "runtime/stream.h"
|
||||
#include "utils/ms_context.h"
|
||||
#include "plugin/device/ascend/hal/device/ge_runtime/task_info.h"
|
||||
#include "plugin/device/ascend/hal/device/ascend_stream_manager.h"
|
||||
#include "backend/common/session/anf_runtime_algorithm.h"
|
||||
#include "include/common/utils/anfalgo.h"
|
||||
|
||||
|
@ -38,18 +39,28 @@ bool RecvKernel::Init(const AnfNodePtr &anf_node) {
|
|||
MS_LOG(EXCEPTION) << "RecvKernel has no attr kAttrEventId";
|
||||
}
|
||||
event_id_ = GetValue<uint32_t>(primitive->GetAttr(kAttrEventId));
|
||||
|
||||
if (common::AnfAlgo::HasNodeAttr(kAttrWaitEvent, anf_node->cast<CNodePtr>())) {
|
||||
event_ = reinterpret_cast<rtEvent_t>(GetValue<uintptr_t>(primitive->GetAttr(kAttrWaitEvent)));
|
||||
}
|
||||
MS_LOG(INFO) << "recv op event_id_:" << event_id_;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool RecvKernel::Launch(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &, void *stream_ptr) {
|
||||
rtEvent_t stream_event{};
|
||||
auto status = rtStreamWaitEvent(stream_ptr, stream_event);
|
||||
MS_EXCEPTION_IF_NULL(event_);
|
||||
MS_EXCEPTION_IF_NULL(stream_ptr);
|
||||
auto status = rtStreamWaitEvent(stream_ptr, event_);
|
||||
if (status != RT_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "Recv rtStreamWaitEvent failed!";
|
||||
return false;
|
||||
}
|
||||
|
||||
status = rtEventReset(event_, stream_ptr);
|
||||
if (status != RT_ERROR_NONE) {
|
||||
MS_LOG(EXCEPTION) << "rtEventReset failed, ret:" << status;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
@ -37,6 +37,7 @@ class RecvKernel : public RtKernel {
|
|||
|
||||
private:
|
||||
uint32_t event_id_;
|
||||
rtEvent_t event_;
|
||||
};
|
||||
|
||||
MS_REG_RTKERNEL(streamrecv, RecvKernel);
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#include "plugin/device/ascend/kernel/rts/send.h"
|
||||
#include "runtime/event.h"
|
||||
#include "plugin/device/ascend/hal/device/ge_runtime/task_info.h"
|
||||
#include "plugin/device/ascend/hal/device/ascend_stream_manager.h"
|
||||
#include "backend/common/session/anf_runtime_algorithm.h"
|
||||
#include "include/common/utils/anfalgo.h"
|
||||
|
||||
|
@ -37,14 +38,19 @@ bool SendKernel::Init(const AnfNodePtr &anf_node) {
|
|||
MS_LOG(EXCEPTION) << "SendKernel has no attr kAttrEventId";
|
||||
}
|
||||
event_id_ = GetValue<uint32_t>(primitive->GetAttr(kAttrEventId));
|
||||
|
||||
if (common::AnfAlgo::HasNodeAttr(kAttrRecordEvent, anf_node->cast<CNodePtr>())) {
|
||||
event_ = reinterpret_cast<rtEvent_t>(GetValue<uintptr_t>(primitive->GetAttr(kAttrRecordEvent)));
|
||||
}
|
||||
MS_LOG(INFO) << "send op event id:" << event_id_;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool SendKernel::Launch(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &, void *stream_ptr) {
|
||||
rtEvent_t event{};
|
||||
rtError_t status = rtEventRecord(event, stream_ptr);
|
||||
MS_EXCEPTION_IF_NULL(event_);
|
||||
MS_EXCEPTION_IF_NULL(stream_ptr);
|
||||
rtError_t status = rtEventRecord(event_, stream_ptr);
|
||||
if (status != RT_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "Send op rtEventRecord failed!";
|
||||
return false;
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include <vector>
|
||||
#include "plugin/device/ascend/kernel/rts/rt_kernel.h"
|
||||
#include "plugin/device/ascend/kernel/rts/rt_kernel_info.h"
|
||||
#include "plugin/device/ascend/hal/device/ascend_event.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
@ -35,6 +36,7 @@ class SendKernel : public RtKernel {
|
|||
|
||||
private:
|
||||
uint32_t event_id_;
|
||||
rtEvent_t event_;
|
||||
};
|
||||
|
||||
MS_REG_RTKERNEL(streamsend, SendKernel);
|
||||
|
|
|
@ -222,11 +222,11 @@ void CacheSendRecvCNodesForAllReduce(const std::shared_ptr<session::KernelGraph>
|
|||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
std::pair<CNodePtr, CNodePtr> send_recv_nodes(send_node, recv_node);
|
||||
if (common::AnfAlgo::GetCNodeName(mock_send_node) == kAllReduceOpName) {
|
||||
kernel_graph->InsertToSendRecvPair(mock_send_node, send_recv_nodes);
|
||||
kernel_graph->InsertSendRecvPairForParallelOpOutputs(mock_send_node, send_recv_nodes);
|
||||
}
|
||||
|
||||
if (common::AnfAlgo::GetCNodeName(mock_recv_node) == kAllReduceOpName) {
|
||||
kernel_graph->InsertFromSendRecvPair(mock_recv_node, send_recv_nodes);
|
||||
kernel_graph->InsertSendRecvPairForParallelOpInputs(mock_recv_node, send_recv_nodes);
|
||||
}
|
||||
}
|
||||
} // namespace gpu
|
||||
|
|
|
@ -1523,68 +1523,82 @@ void GraphScheduler::LinkControlArrowBySkippedNode(AbstractActor *to_actor, cons
|
|||
|
||||
void GraphScheduler::LinkControlArrowBySendRecvNodes(const KernelGraphPtr &graph) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
for (auto &from_iter : graph->allreduce_from_send_recv_pairs()) {
|
||||
auto to_allreduce_node = from_iter.first;
|
||||
auto from_send_node = from_iter.second.first;
|
||||
auto from_recv_node = from_iter.second.second;
|
||||
MS_EXCEPTION_IF_NULL(to_allreduce_node);
|
||||
MS_EXCEPTION_IF_NULL(from_send_node);
|
||||
MS_EXCEPTION_IF_NULL(from_recv_node);
|
||||
MS_LOG(INFO) << "Link control arrow for to_allreduce_node: " << to_allreduce_node->fullname_with_scope();
|
||||
auto to_allreduce_actor = FetchActor(to_allreduce_node->fullname_with_scope());
|
||||
auto from_send_actor = FetchActor(from_send_node->fullname_with_scope());
|
||||
auto from_recv_actor = FetchActor(from_recv_node->fullname_with_scope());
|
||||
MS_EXCEPTION_IF_NULL(to_allreduce_actor);
|
||||
MS_EXCEPTION_IF_NULL(from_send_actor);
|
||||
MS_EXCEPTION_IF_NULL(from_recv_actor);
|
||||
for (auto &from_iter : graph->send_recv_pairs_for_parallel_op_inputs()) {
|
||||
auto parallel_node = from_iter.first;
|
||||
for (auto pair : from_iter.second) {
|
||||
auto send_node = pair.first;
|
||||
auto recv_node = pair.second;
|
||||
MS_EXCEPTION_IF_NULL(parallel_node);
|
||||
MS_EXCEPTION_IF_NULL(send_node);
|
||||
MS_EXCEPTION_IF_NULL(recv_node);
|
||||
MS_LOG(INFO) << "Link control arrow for parallel node input: " << parallel_node->fullname_with_scope();
|
||||
auto parallel_actor = FetchActor(parallel_node->fullname_with_scope());
|
||||
auto send_actor = FetchActor(send_node->fullname_with_scope());
|
||||
auto recv_actor = FetchActor(recv_node->fullname_with_scope());
|
||||
MS_EXCEPTION_IF_NULL(parallel_actor);
|
||||
MS_EXCEPTION_IF_NULL(send_actor);
|
||||
MS_EXCEPTION_IF_NULL(recv_actor);
|
||||
|
||||
// inputs of to_allreduce_actor --> from_send_actor
|
||||
for (auto &input_aid : to_allreduce_actor->input_data_arrow_aids_) {
|
||||
auto input_actor = dynamic_cast<KernelActor *>(FetchActor(input_aid.Name()));
|
||||
if (input_actor != nullptr) {
|
||||
SchedulerHelper::AddControlArrow(input_actor, from_send_actor);
|
||||
// inputs of to_allreduce_actor --> from_send_actor
|
||||
for (auto &input_aid : parallel_actor->input_data_arrow_aids_) {
|
||||
auto input_actor = dynamic_cast<KernelActor *>(FetchActor(input_aid.Name()));
|
||||
if (input_actor != nullptr) {
|
||||
SchedulerHelper::AddControlArrow(input_actor, send_actor);
|
||||
}
|
||||
}
|
||||
// from_send_actor --> from_recv_actor
|
||||
SchedulerHelper::AddControlArrow(send_actor, recv_actor);
|
||||
// from_recv_actor --> to_allreduce_actor
|
||||
SchedulerHelper::AddControlArrow(recv_actor, parallel_actor);
|
||||
}
|
||||
// from_send_actor --> from_recv_actor
|
||||
SchedulerHelper::AddControlArrow(from_send_actor, from_recv_actor);
|
||||
// from_recv_actor --> to_allreduce_actor
|
||||
SchedulerHelper::AddControlArrow(from_recv_actor, to_allreduce_actor);
|
||||
}
|
||||
|
||||
for (auto &to_iter : graph->allreduce_to_send_recv_pairs()) {
|
||||
auto from_allreduce_node = to_iter.first;
|
||||
auto to_send_node = to_iter.second.first;
|
||||
auto to_recv_node = to_iter.second.second;
|
||||
MS_EXCEPTION_IF_NULL(from_allreduce_node);
|
||||
MS_EXCEPTION_IF_NULL(to_send_node);
|
||||
MS_EXCEPTION_IF_NULL(to_recv_node);
|
||||
MS_LOG(INFO) << "Link control arrow for from_allreduce_node: " << from_allreduce_node->fullname_with_scope();
|
||||
auto from_allreduce_actor = FetchActor(from_allreduce_node->fullname_with_scope());
|
||||
auto to_send_actor = FetchActor(to_send_node->fullname_with_scope());
|
||||
auto to_recv_actor = dynamic_cast<KernelActor *>(FetchActor(to_recv_node->fullname_with_scope()));
|
||||
MS_EXCEPTION_IF_NULL(from_allreduce_actor);
|
||||
MS_EXCEPTION_IF_NULL(to_send_actor);
|
||||
MS_EXCEPTION_IF_NULL(to_recv_actor);
|
||||
for (auto &to_iter : graph->send_recv_pairs_for_parallel_op_outputs()) {
|
||||
auto parallel_node = to_iter.first;
|
||||
for (auto pair : to_iter.second) {
|
||||
auto send_node = pair.first;
|
||||
auto recv_node = pair.second;
|
||||
MS_EXCEPTION_IF_NULL(parallel_node);
|
||||
MS_EXCEPTION_IF_NULL(send_node);
|
||||
MS_EXCEPTION_IF_NULL(recv_node);
|
||||
MS_LOG(INFO) << "Link control arrow for parallel node output: " << parallel_node->fullname_with_scope();
|
||||
auto parallel_actor = FetchActor(parallel_node->fullname_with_scope());
|
||||
auto send_actor = FetchActor(send_node->fullname_with_scope());
|
||||
auto recv_actor = dynamic_cast<KernelActor *>(FetchActor(recv_node->fullname_with_scope()));
|
||||
MS_EXCEPTION_IF_NULL(parallel_actor);
|
||||
MS_EXCEPTION_IF_NULL(send_actor);
|
||||
MS_EXCEPTION_IF_NULL(recv_actor);
|
||||
|
||||
// from_allreduce_actor --> to_send_actor
|
||||
SchedulerHelper::AddControlArrow(from_allreduce_actor, to_send_actor);
|
||||
// to_send_actor --> to_recv_actor
|
||||
SchedulerHelper::AddControlArrow(to_send_actor, to_recv_actor);
|
||||
// to_recv_actor --> outputs of from_allreduce_actor
|
||||
for (auto &output_data_arrow : from_allreduce_actor->output_data_arrows_) {
|
||||
auto output_actor = FetchActor(output_data_arrow->to_op_id_.Name());
|
||||
if (output_actor != nullptr) {
|
||||
SchedulerHelper::AddControlArrow(to_recv_actor, output_actor);
|
||||
// from_allreduce_actor --> to_send_actor
|
||||
SchedulerHelper::AddControlArrow(parallel_actor, send_actor);
|
||||
// to_send_actor --> to_recv_actor
|
||||
SchedulerHelper::AddControlArrow(send_actor, recv_actor);
|
||||
// to_recv_actor --> outputs of from_allreduce_actor
|
||||
for (auto &output_data_arrow : parallel_actor->output_data_arrows_) {
|
||||
auto output_actor = FetchActor(output_data_arrow->to_op_id_.Name());
|
||||
if (output_actor != nullptr) {
|
||||
SchedulerHelper::AddControlArrow(recv_actor, output_actor);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// In the scene of allreduce op and computing op parallel multi stream, the input memory of allreduce can be
|
||||
// reused only when the recv node runs finished, which is expressed by the reference count increased.
|
||||
for (size_t i = 0; i < common::AnfAlgo::GetInputTensorNum(from_allreduce_node); ++i) {
|
||||
auto device_tensor = AnfAlgo::GetPrevNodeMutableOutputAddr(from_allreduce_node, i, false);
|
||||
MS_EXCEPTION_IF_NULL(device_tensor);
|
||||
UpdateRefCount(device_tensor.get());
|
||||
(void)to_recv_actor->external_reference_tensors_.emplace_back(device_tensor.get());
|
||||
// In the scene of allreduce op and computing op parallel multi stream, the input memory of allreduce can be
|
||||
// reused only when the recv node runs finished, which is expressed by the reference count increased.
|
||||
for (size_t i = 0; i < common::AnfAlgo::GetInputTensorNum(parallel_node); ++i) {
|
||||
auto device_tensor = AnfAlgo::GetPrevNodeMutableOutputAddr(parallel_node, i, false);
|
||||
MS_EXCEPTION_IF_NULL(device_tensor);
|
||||
UpdateRefCount(device_tensor.get());
|
||||
(void)recv_actor->external_reference_tensors_.emplace_back(device_tensor.get());
|
||||
}
|
||||
|
||||
auto kernel_mod = AnfAlgo::GetKernelMod(parallel_node);
|
||||
MS_EXCEPTION_IF_NULL(kernel_mod);
|
||||
auto workspace_num = kernel_mod->GetWorkspaceSizeList().size();
|
||||
for (size_t i = 0; i < workspace_num; ++i) {
|
||||
auto device_tensor = AnfAlgo::GetMutableWorkspaceAddr(parallel_node, i);
|
||||
MS_EXCEPTION_IF_NULL(device_tensor);
|
||||
UpdateRefCount(device_tensor.get());
|
||||
(void)recv_actor->external_reference_tensors_.emplace_back(device_tensor.get());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue