forked from mindspore-Ecosystem/mindspore
!10391 enable loop sink when no getnext in execution orders
From: @laiyongqiang Reviewed-by: Signed-off-by:
This commit is contained in:
commit
6d51fc558f
|
@ -762,6 +762,39 @@ bool AscendStreamAssign::IsProcessedStream(uint32_t stream_id) {
|
|||
return false;
|
||||
}
|
||||
|
||||
bool AscendStreamAssign::IsAllOutGraphOut(const KernelGraphPtr &graph, const CNodePtr &cnode) {
|
||||
auto cnode_out_num = AnfAlgo::GetOutputTensorNum(cnode);
|
||||
auto nodes = AnfAlgo::GetAllOutput(graph->output(), {prim::kPrimTupleGetItem});
|
||||
std::set<int> output_index_set;
|
||||
// Assign Communicate Op Memory firstly.
|
||||
for (const auto &node : nodes) {
|
||||
auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true);
|
||||
MS_EXCEPTION_IF_NULL(item_with_index.first);
|
||||
if (!item_with_index.first->isa<CNode>() || !AnfAlgo::IsRealKernel(item_with_index.first)) {
|
||||
continue;
|
||||
}
|
||||
if (item_with_index.first == cnode) {
|
||||
output_index_set.insert(item_with_index.second);
|
||||
}
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "Node " << cnode->fullname_with_scope() << " has " << cnode_out_num
|
||||
<< " outputs, in graph output num:" << output_index_set.size();
|
||||
return cnode_out_num == output_index_set.size();
|
||||
}
|
||||
|
||||
vector<CNodePtr>::iterator AscendStreamAssign::FindGraphEnd(vector<CNodePtr>::iterator begin,
|
||||
vector<CNodePtr>::iterator end) {
|
||||
while (begin != end) {
|
||||
if (AnfAlgo::HasNodeAttr(kAttrFpBpEnd, *begin)) {
|
||||
MS_LOG(INFO) << "FpBp end op is " << (*begin)->fullname_with_scope();
|
||||
return begin;
|
||||
}
|
||||
++begin;
|
||||
}
|
||||
return end;
|
||||
}
|
||||
|
||||
// section5
|
||||
void AscendStreamAssign::InsertEventForHcomParallel(const NotNull<KernelGraphPtr> &graph_ptr) {
|
||||
MS_LOG(INFO) << "Start";
|
||||
|
@ -780,15 +813,23 @@ void AscendStreamAssign::InsertEventCommonDependHcom(const NotNull<KernelGraphPt
|
|||
while (it != cnodes.end()) {
|
||||
MS_EXCEPTION_IF_NULL(*it);
|
||||
if (IsHcom(*it)) {
|
||||
auto cur_hcom_node = *it;
|
||||
CNodePtr send_cnode_ptr = CreateSendApplyKernel(graph_ptr, cur_event_id, AnfAlgo::GetStreamId(*it));
|
||||
it = cnodes.insert(it + 1, send_cnode_ptr);
|
||||
|
||||
auto target = FindTargetOp(it, cnodes.end(), *(it - 1), true);
|
||||
auto target = FindTargetOp(it, cnodes.end(), cur_hcom_node, true);
|
||||
if (target == cnodes.end()) {
|
||||
MS_LOG(WARNING) << "Hcom node:" << (*(it - 1))->fullname_with_scope()
|
||||
<< ", can't find target for insert recv op, no insert send/recv";
|
||||
it = cnodes.erase(it);
|
||||
continue;
|
||||
if (IsAllOutGraphOut(graph_ptr, cur_hcom_node)) {
|
||||
// if hcom's all output is graph output, we need to insert send/recv to fpbp end in data sink mode
|
||||
target = FindGraphEnd(it, cnodes.end());
|
||||
}
|
||||
|
||||
if (target == cnodes.end()) {
|
||||
MS_LOG(WARNING) << "Hcom node:" << (*(it - 1))->fullname_with_scope()
|
||||
<< ", can't find target for insert recv op, no insert send/recv";
|
||||
it = cnodes.erase(it);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// deal recv op
|
||||
|
@ -824,7 +865,7 @@ void AscendStreamAssign::InsertEventHcomDependCommonBak(const NotNull<KernelGrap
|
|||
continue;
|
||||
}
|
||||
|
||||
// get the input which located in the lastr exe orders
|
||||
// get the input which located in the last exe orders
|
||||
vector<CNodePtr> inputs_cnode = GetLastInputCnode(graph_ptr, cur_cnode_ptr);
|
||||
if (inputs_cnode.empty()) {
|
||||
cnodes.emplace_back(cur_cnode_ptr);
|
||||
|
|
|
@ -212,6 +212,8 @@ class AscendStreamAssign {
|
|||
std::map<CNodePtr, CNodePtr> event_map_{};
|
||||
std::set<uint32_t> middle_active_streams_{};
|
||||
// new policy end
|
||||
bool IsAllOutGraphOut(const KernelGraphPtr &graph, const CNodePtr &cnode);
|
||||
vector<CNodePtr>::iterator FindGraphEnd(vector<CNodePtr>::iterator begin, vector<CNodePtr>::iterator end);
|
||||
};
|
||||
} // namespace ascend
|
||||
} // namespace device
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include <algorithm>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "utils/ms_context.h"
|
||||
|
@ -104,7 +105,18 @@ CNodePtr KernelAdjust::CreateRecvApplyKernel(const std::shared_ptr<session::Kern
|
|||
return recv_node_ptr;
|
||||
}
|
||||
|
||||
bool KernelAdjust::ExitIndependent(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) {
|
||||
bool KernelAdjust::ExistGetNext(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
|
||||
const std::vector<CNodePtr> &cnode_list = kernel_graph_ptr->execution_order();
|
||||
for (const auto &cnode : cnode_list) {
|
||||
if (AnfAlgo::GetCNodeName(cnode) == kGetNextOpName) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool KernelAdjust::ExistIndependent(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
|
||||
const auto &exe_orders = kernel_graph_ptr->execution_order();
|
||||
for (const auto &node : exe_orders) {
|
||||
|
@ -128,8 +140,13 @@ void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph>
|
|||
MS_LOG(INFO) << "KernelGraph:" << kernel_graph_ptr->graph_id() << " is dynamic shape, skip InsertSwitchLoop";
|
||||
return;
|
||||
}
|
||||
bool eos_mode = ConfigManager::GetInstance().iter_num() == INT32_MAX;
|
||||
ReorderGetNext(kernel_graph_ptr);
|
||||
bool exist_getnext = ExistGetNext(kernel_graph_ptr);
|
||||
bool eos_mode = ConfigManager::GetInstance().iter_num() == INT32_MAX && exist_getnext;
|
||||
MS_LOG(INFO) << "GetNext exist:" << exist_getnext << " End of Sequence mode:" << eos_mode
|
||||
<< " iter num:" << ConfigManager::GetInstance().iter_num();
|
||||
if (exist_getnext) {
|
||||
ReorderGetNext(kernel_graph_ptr);
|
||||
}
|
||||
std::map<std::string, mindspore::ParameterPtr> switch_loop_input;
|
||||
CreateSwitchOpParameters(kernel_graph_ptr, &switch_loop_input);
|
||||
|
||||
|
@ -159,84 +176,96 @@ void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph>
|
|||
std::vector<uint32_t> getnext_active_streams;
|
||||
std::vector<uint32_t> fpbp_active_streams;
|
||||
CNodePtr getnext_cnode;
|
||||
uint32_t getnext_switch_stream_id = UINT32_MAX;
|
||||
uint32_t fpbp_start_event_id = UINT32_MAX;
|
||||
uint32_t eos_start_event_id = UINT32_MAX;
|
||||
uint32_t eos_done_event_id = UINT32_MAX;
|
||||
size_t i = 0;
|
||||
|
||||
// getnext loop process
|
||||
// getnext loop stream switch op
|
||||
CNodePtr getnext_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input, kGetNextStreamSwitch);
|
||||
MS_EXCEPTION_IF_NULL(getnext_switch_app);
|
||||
uint32_t getnext_switch_stream_id = resource_manager.ApplyNewStream();
|
||||
AnfAlgo::SetStreamId(getnext_switch_stream_id, getnext_switch_app.get());
|
||||
exec_order.push_back(getnext_switch_app);
|
||||
if (exist_getnext) {
|
||||
// getnext loop stream switch op
|
||||
getnext_switch_stream_id = resource_manager.ApplyNewStream();
|
||||
uint32_t getnext_stream_id = resource_manager.ApplyNewStream();
|
||||
CNodePtr getnext_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input, kGetNextStreamSwitch);
|
||||
MS_EXCEPTION_IF_NULL(getnext_switch_app);
|
||||
AnfAlgo::SetStreamId(getnext_switch_stream_id, getnext_switch_app.get());
|
||||
// update getnext loop stream switch true_branch_stream attr
|
||||
AnfAlgo::SetNodeAttr(kStreamNeedActivedFirst, MakeValue<bool>(true), getnext_switch_app);
|
||||
AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue<uint32_t>(getnext_stream_id), getnext_switch_app);
|
||||
AnfAlgo::SetNodeAttr(kAttrStreamSwitchKind, MakeValue<uint32_t>(kGetNextStreamSwitch), getnext_switch_app);
|
||||
exec_order.push_back(getnext_switch_app);
|
||||
MS_LOG(INFO) << "GetNext loop insert Stream Switch " << getnext_switch_app->fullname_with_scope();
|
||||
|
||||
// getnext op
|
||||
uint32_t getnext_stream_id = resource_manager.ApplyNewStream();
|
||||
size_t i = 0;
|
||||
for (; i < orders.size(); i++) {
|
||||
auto node = orders[i];
|
||||
exec_order.push_back(node);
|
||||
AnfAlgo::SetStreamId(getnext_stream_id, exec_order[exec_order.size() - 1].get());
|
||||
if (AnfAlgo::GetCNodeName(node) == kGetNextOpName) {
|
||||
getnext_cnode = node;
|
||||
break;
|
||||
// getnext op
|
||||
for (; i < orders.size(); i++) {
|
||||
auto node = orders[i];
|
||||
exec_order.push_back(node);
|
||||
AnfAlgo::SetStreamId(getnext_stream_id, exec_order[exec_order.size() - 1].get());
|
||||
if (AnfAlgo::GetCNodeName(node) == kGetNextOpName) {
|
||||
getnext_cnode = node;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// getnext loop fpbp start send
|
||||
fpbp_start_event_id = resource_manager.ApplyNewEvent();
|
||||
CNodePtr fpbp_start_send = CreateSendApplyKernel(kernel_graph_ptr, fpbp_start_event_id);
|
||||
AnfAlgo::SetStreamId(getnext_stream_id, fpbp_start_send.get());
|
||||
exec_order.push_back(fpbp_start_send);
|
||||
MS_LOG(INFO) << "GetNext loop insert FpBp start Send " << fpbp_start_send->fullname_with_scope();
|
||||
|
||||
if (eos_mode) {
|
||||
// getnext loop eos start send
|
||||
eos_start_event_id = resource_manager.ApplyNewEvent();
|
||||
CNodePtr eos_start_send = CreateSendApplyKernel(kernel_graph_ptr, eos_start_event_id);
|
||||
AnfAlgo::SetStreamId(getnext_stream_id, eos_start_send.get());
|
||||
exec_order.push_back(eos_start_send);
|
||||
MS_LOG(INFO) << "GetNext loop insert EoS start Send " << eos_start_send->fullname_with_scope();
|
||||
}
|
||||
}
|
||||
|
||||
// update getnext loop stream switch true_branch_stream attr
|
||||
AnfAlgo::SetNodeAttr(kStreamNeedActivedFirst, MakeValue<bool>(true), getnext_switch_app);
|
||||
AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue<uint32_t>(getnext_stream_id), getnext_switch_app);
|
||||
AnfAlgo::SetNodeAttr(kAttrStreamSwitchKind, MakeValue<uint32_t>(kGetNextStreamSwitch), getnext_switch_app);
|
||||
|
||||
// getnext loop fpbp start send
|
||||
uint32_t fpbp_start_event_id = resource_manager.ApplyNewEvent();
|
||||
CNodePtr fpbp_start_send = CreateSendApplyKernel(kernel_graph_ptr, fpbp_start_event_id);
|
||||
AnfAlgo::SetStreamId(getnext_stream_id, fpbp_start_send.get());
|
||||
exec_order.push_back(fpbp_start_send);
|
||||
|
||||
// End Of Sequence loop process
|
||||
if (eos_mode) {
|
||||
// getnext loop eos start send
|
||||
uint32_t eos_start_event_id = resource_manager.ApplyNewEvent();
|
||||
CNodePtr eos_start_send = CreateSendApplyKernel(kernel_graph_ptr, eos_start_event_id);
|
||||
AnfAlgo::SetStreamId(getnext_stream_id, eos_start_send.get());
|
||||
exec_order.push_back(eos_start_send);
|
||||
|
||||
// End Of Sequence loop process
|
||||
// eos loop stream switch
|
||||
uint32_t eos_switch_stream_id = resource_manager.ApplyNewStream();
|
||||
uint32_t eos_stream_id = resource_manager.ApplyNewStream();
|
||||
CNodePtr eos_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input, kEosStreamSwitch);
|
||||
MS_EXCEPTION_IF_NULL(eos_switch_app);
|
||||
uint32_t eos_switch_stream_id = resource_manager.ApplyNewStream();
|
||||
AnfAlgo::SetStreamId(eos_switch_stream_id, eos_switch_app.get());
|
||||
AnfAlgo::SetNodeAttr(kStreamNeedActivedFirst, MakeValue<bool>(true), eos_switch_app);
|
||||
exec_order.push_back(eos_switch_app);
|
||||
|
||||
// eos loop eos start recv
|
||||
CNodePtr eos_start_recv = CreateRecvApplyKernel(kernel_graph_ptr, eos_start_event_id);
|
||||
uint32_t eos_stream_id = resource_manager.ApplyNewStream();
|
||||
AnfAlgo::SetStreamId(eos_stream_id, eos_start_recv.get());
|
||||
exec_order.push_back(eos_start_recv);
|
||||
|
||||
// update eos loop stream switch true_branch_stream attr
|
||||
AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue<uint32_t>(eos_stream_id), eos_switch_app);
|
||||
AnfAlgo::SetNodeAttr(kAttrStreamSwitchKind, MakeValue<uint32_t>(kEosStreamSwitch), eos_switch_app);
|
||||
exec_order.push_back(eos_switch_app);
|
||||
MS_LOG(INFO) << "EoS loop insert Stream Switch " << eos_switch_app->fullname_with_scope();
|
||||
|
||||
// eos loop eos start recv
|
||||
CNodePtr eos_start_recv = CreateRecvApplyKernel(kernel_graph_ptr, eos_start_event_id);
|
||||
AnfAlgo::SetStreamId(eos_stream_id, eos_start_recv.get());
|
||||
exec_order.push_back(eos_start_recv);
|
||||
MS_LOG(INFO) << "EoS loop insert EoS Recv " << eos_start_recv->fullname_with_scope();
|
||||
|
||||
// EndOfSequence op
|
||||
CNodePtr end_of_sequence_op = CreateEndOfSequenceOP(kernel_graph_ptr, getnext_cnode);
|
||||
MS_EXCEPTION_IF_NULL(end_of_sequence_op);
|
||||
AnfAlgo::SetStreamId(eos_stream_id, end_of_sequence_op.get());
|
||||
exec_order.push_back(end_of_sequence_op);
|
||||
MS_LOG(INFO) << "EoS loop insert Eos Op " << end_of_sequence_op->fullname_with_scope();
|
||||
|
||||
// eos loop eos done send
|
||||
eos_done_event_id = resource_manager.ApplyNewEvent();
|
||||
CNodePtr eos_done_send = CreateSendApplyKernel(kernel_graph_ptr, eos_done_event_id);
|
||||
AnfAlgo::SetStreamId(eos_stream_id, eos_done_send.get());
|
||||
exec_order.push_back(eos_done_send);
|
||||
MS_LOG(INFO) << "EoS loop insert EoS done Send " << eos_done_send->fullname_with_scope();
|
||||
|
||||
// eos loop stream active
|
||||
fpbp_active_streams.push_back(eos_switch_stream_id);
|
||||
}
|
||||
|
||||
bool exit_independent = ExitIndependent(kernel_graph_ptr);
|
||||
if (exit_independent) {
|
||||
bool exist_independent = ExistIndependent(kernel_graph_ptr);
|
||||
if (exist_independent) {
|
||||
// Independet parallel
|
||||
CNodePtr independent_switch_app =
|
||||
CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input, kIndependentStreamSwitch);
|
||||
|
@ -246,68 +275,80 @@ void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph>
|
|||
AnfAlgo::SetNodeAttr(kStreamNeedActivedFirst, MakeValue<bool>(true), independent_switch_app);
|
||||
AnfAlgo::SetNodeAttr(kAttrStreamSwitchKind, MakeValue<uint32_t>(kIndependentStreamSwitch), independent_switch_app);
|
||||
exec_order.push_back(independent_switch_app);
|
||||
MS_LOG(INFO) << "Independent op loop insert Stream Switch " << independent_switch_app->fullname_with_scope();
|
||||
}
|
||||
|
||||
// fpbp loop process
|
||||
// fpbp loop stream switch
|
||||
uint32_t fpbp_switch_stream_id = resource_manager.ApplyNewStream();
|
||||
uint32_t fpbp_stream_id = resource_manager.ApplyNewStream();
|
||||
CNodePtr fpbp_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input, kFpBpStreamSwitch);
|
||||
MS_EXCEPTION_IF_NULL(fpbp_switch_app);
|
||||
uint32_t fpbp_switch_stream_id = resource_manager.ApplyNewStream();
|
||||
AnfAlgo::SetStreamId(fpbp_switch_stream_id, fpbp_switch_app.get());
|
||||
AnfAlgo::SetNodeAttr(kStreamNeedActivedFirst, MakeValue<bool>(true), fpbp_switch_app);
|
||||
|
||||
exec_order.push_back(fpbp_switch_app);
|
||||
|
||||
// fpbp loop fpbp start recv
|
||||
CNodePtr fpbp_start_recv = CreateRecvApplyKernel(kernel_graph_ptr, fpbp_start_event_id);
|
||||
uint32_t fpbp_stream_id = resource_manager.ApplyNewStream();
|
||||
AnfAlgo::SetStreamId(fpbp_stream_id, fpbp_start_recv.get());
|
||||
exec_order.push_back(fpbp_start_recv);
|
||||
|
||||
// update fpbp loop stream switch true_branch_stream attr
|
||||
AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue<uint32_t>(fpbp_stream_id), fpbp_switch_app);
|
||||
AnfAlgo::SetNodeAttr(kAttrStreamSwitchKind, MakeValue<uint32_t>(kFpBpStreamSwitch), fpbp_switch_app);
|
||||
exec_order.push_back(fpbp_switch_app);
|
||||
MS_LOG(INFO) << "FpBp loop insert Stream Switch " << fpbp_switch_app->fullname_with_scope();
|
||||
|
||||
if (exist_getnext) {
|
||||
// fpbp loop fpbp start recv
|
||||
CNodePtr fpbp_start_recv = CreateRecvApplyKernel(kernel_graph_ptr, fpbp_start_event_id);
|
||||
AnfAlgo::SetStreamId(fpbp_stream_id, fpbp_start_recv.get());
|
||||
exec_order.push_back(fpbp_start_recv);
|
||||
MS_LOG(INFO) << "FpBp loop insert FpBp start Recv " << fpbp_start_recv->fullname_with_scope();
|
||||
}
|
||||
|
||||
// next loop AssignAdd
|
||||
CNodePtr assign_add_one = CreateStreamAssignAddnOP(kernel_graph_ptr, switch_loop_input, false);
|
||||
MS_EXCEPTION_IF_NULL(assign_add_one);
|
||||
AnfAlgo::SetStreamId(fpbp_stream_id, assign_add_one.get());
|
||||
exec_order.push_back(assign_add_one);
|
||||
MS_LOG(INFO) << "FpBp loop insert next loop AssignAdd " << assign_add_one->fullname_with_scope();
|
||||
|
||||
// fpbp memcpy
|
||||
// fpbp getnext output memcpy
|
||||
std::vector<CNodePtr> memcpy_list;
|
||||
std::vector<CNodePtr> other_list;
|
||||
CNodePtr cur_cnode = nullptr;
|
||||
for (size_t idx = i + 1; idx < orders.size(); idx++) {
|
||||
cur_cnode = orders[idx];
|
||||
if (AnfAlgo::HasNodeAttr(kAttrLabelForInsertStreamActive, cur_cnode)) {
|
||||
auto pre_node = orders[idx - 1];
|
||||
auto pre_kernel_name = AnfAlgo::GetCNodeName(pre_node);
|
||||
if (pre_kernel_name == kAtomicAddrCleanOpName) {
|
||||
other_list.pop_back();
|
||||
memcpy_list.push_back(pre_node);
|
||||
if (exist_getnext) {
|
||||
CNodePtr cur_cnode = nullptr;
|
||||
for (size_t idx = i + 1; idx < orders.size(); idx++) {
|
||||
cur_cnode = orders[idx];
|
||||
if (AnfAlgo::HasNodeAttr(kAttrLabelForInsertStreamActive, cur_cnode)) {
|
||||
auto pre_node = orders[idx - 1];
|
||||
auto pre_kernel_name = AnfAlgo::GetCNodeName(pre_node);
|
||||
if (pre_kernel_name == kAtomicAddrCleanOpName) {
|
||||
other_list.pop_back();
|
||||
memcpy_list.push_back(pre_node);
|
||||
}
|
||||
memcpy_list.emplace_back(cur_cnode);
|
||||
} else {
|
||||
other_list.emplace_back(cur_cnode);
|
||||
}
|
||||
memcpy_list.emplace_back(cur_cnode);
|
||||
} else {
|
||||
other_list.emplace_back(cur_cnode);
|
||||
}
|
||||
(void)std::copy(memcpy_list.begin(), memcpy_list.end(), std::back_inserter(exec_order));
|
||||
} else {
|
||||
other_list = orders;
|
||||
}
|
||||
|
||||
(void)std::copy(memcpy_list.begin(), memcpy_list.end(), std::back_inserter(exec_order));
|
||||
|
||||
// fpbp loop eos done recv
|
||||
if (eos_mode) {
|
||||
CNodePtr eos_done_recv = CreateRecvApplyKernel(kernel_graph_ptr, eos_done_event_id);
|
||||
AnfAlgo::SetStreamId(fpbp_stream_id, eos_done_recv.get());
|
||||
exec_order.push_back(eos_done_recv);
|
||||
MS_LOG(INFO) << "FpBp loop insert EoS done Recv " << eos_done_recv->fullname_with_scope();
|
||||
}
|
||||
|
||||
// stream active to activate getnext loop
|
||||
CNodePtr getnext_active_app = CreateStreamActiveOp(kernel_graph_ptr);
|
||||
MS_EXCEPTION_IF_NULL(getnext_active_app);
|
||||
getnext_active_streams.push_back(getnext_switch_stream_id);
|
||||
AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue<std::vector<uint32_t>>(getnext_active_streams),
|
||||
getnext_active_app);
|
||||
exec_order.push_back(getnext_active_app);
|
||||
if (exist_getnext) {
|
||||
CNodePtr getnext_active_app = CreateStreamActiveOp(kernel_graph_ptr);
|
||||
MS_EXCEPTION_IF_NULL(getnext_active_app);
|
||||
getnext_active_streams.push_back(getnext_switch_stream_id);
|
||||
AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue<std::vector<uint32_t>>(getnext_active_streams),
|
||||
getnext_active_app);
|
||||
exec_order.push_back(getnext_active_app);
|
||||
MS_LOG(INFO) << "FpBp loop insert GetNext loop Stream Active " << getnext_active_app->fullname_with_scope();
|
||||
}
|
||||
|
||||
// fpbp loop other ops
|
||||
(void)std::copy(other_list.begin(), other_list.end(), std::back_inserter(exec_order));
|
||||
|
@ -315,7 +356,9 @@ void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph>
|
|||
// current assign add op
|
||||
CNodePtr cur_assign_add = CreateStreamAssignAddnOP(kernel_graph_ptr, switch_loop_input, true);
|
||||
MS_EXCEPTION_IF_NULL(cur_assign_add);
|
||||
AnfAlgo::SetNodeAttr(kAttrFpBpEnd, MakeValue<bool>(true), cur_assign_add);
|
||||
exec_order.push_back(cur_assign_add);
|
||||
MS_LOG(INFO) << "FpBp loop insert current loop AssignAdd " << cur_assign_add->fullname_with_scope();
|
||||
|
||||
// stream active to activate fpbp loop and eos loop
|
||||
CNodePtr fpbp_active_app = CreateStreamActiveOp(kernel_graph_ptr);
|
||||
|
@ -323,6 +366,7 @@ void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph>
|
|||
fpbp_active_streams.push_back(fpbp_switch_stream_id);
|
||||
AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue<std::vector<uint32_t>>(fpbp_active_streams), fpbp_active_app);
|
||||
exec_order.push_back(fpbp_active_app);
|
||||
MS_LOG(INFO) << "FpBp loop insert FpBp loop and Eos loop Stream Active " << fpbp_active_app->fullname_with_scope();
|
||||
|
||||
kernel_graph_ptr->set_execution_order(exec_order);
|
||||
}
|
||||
|
|
|
@ -86,7 +86,8 @@ class KernelAdjust {
|
|||
void LoadSwitchInputs(std::vector<tensor::TensorPtr> *inputs);
|
||||
void InsertProfilingKernel(const ProfilingTraceInfo &profiling_trace_info,
|
||||
NotNull<session::KernelGraph *> kernel_graph_ptr);
|
||||
bool ExitIndependent(const std::shared_ptr<session::KernelGraph> &graph_ptr);
|
||||
bool ExistIndependent(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr);
|
||||
bool ExistGetNext(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr);
|
||||
};
|
||||
} // namespace device
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -317,6 +317,7 @@ constexpr auto kAttrOutputUsedNum = "output_used_num";
|
|||
constexpr auto kAttrHasBias = "has_bias";
|
||||
constexpr auto kAttrN = "n";
|
||||
constexpr auto kAttrLabelForInsertStreamActive = "label_for_insert_stream_active";
|
||||
constexpr auto kAttrFpBpEnd = "fpbp_end";
|
||||
constexpr auto kAttrFusion = "fusion";
|
||||
constexpr auto kAttrGroup = "group";
|
||||
constexpr auto kAttrOp = "op";
|
||||
|
|
Loading…
Reference in New Issue