forked from mindspore-Ecosystem/mindspore
parent
abe89923d8
commit
6362e954df
|
@ -1163,31 +1163,5 @@ bool AnfRuntimeAlgorithm::IsCondControlKernel(const CNodePtr &node) {
|
||||||
auto input = node->input(kAnfPrimitiveIndex);
|
auto input = node->input(kAnfPrimitiveIndex);
|
||||||
return IsPrimitive(input, prim::kPrimLabelGoto) || IsPrimitive(input, prim::kPrimLabelSwitch);
|
return IsPrimitive(input, prim::kPrimLabelGoto) || IsPrimitive(input, prim::kPrimLabelSwitch);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool AnfRuntimeAlgorithm::IsIndependentNode(const CNodePtr &node) {
|
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
|
||||||
if (AnfAlgo::GetKernelType(node) != AICPU_KERNEL) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (AnfAlgo::GetCNodeName(node) == kGetNextOpName) {
|
|
||||||
MS_LOG(INFO) << "GetNext should not be independent node";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
uint32_t input_nums = AnfAlgo::GetInputTensorNum(node);
|
|
||||||
if (input_nums == 0) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto inputs = node->inputs();
|
|
||||||
for (size_t i = 1; i < inputs.size(); i++) {
|
|
||||||
if (!inputs[i]->isa<ValueNode>()) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace session
|
} // namespace session
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -212,7 +212,6 @@ class AnfRuntimeAlgorithm {
|
||||||
// get fix output precision from prev node, input_idx is the input index of current node related to prev node.
|
// get fix output precision from prev node, input_idx is the input index of current node related to prev node.
|
||||||
static TypeId GetPrevNodeOutputPrecision(const AnfNodePtr &node, size_t input_idx);
|
static TypeId GetPrevNodeOutputPrecision(const AnfNodePtr &node, size_t input_idx);
|
||||||
static bool IsCondControlKernel(const CNodePtr &node);
|
static bool IsCondControlKernel(const CNodePtr &node);
|
||||||
static bool IsIndependentNode(const CNodePtr &node);
|
|
||||||
};
|
};
|
||||||
} // namespace session
|
} // namespace session
|
||||||
using AnfAlgo = session::AnfRuntimeAlgorithm;
|
using AnfAlgo = session::AnfRuntimeAlgorithm;
|
||||||
|
|
|
@ -180,32 +180,20 @@ size_t LoadCtrlInputTensor(const std::shared_ptr<KernelGraph> &graph, std::vecto
|
||||||
if (inputs_params == nullptr) {
|
if (inputs_params == nullptr) {
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
if (inputs_params->size() < 3) {
|
if (inputs_params->size() < 2) {
|
||||||
MS_LOG(EXCEPTION) << "Illegal inputs_params size";
|
MS_LOG(EXCEPTION) << "Illegal inputs_params size";
|
||||||
}
|
}
|
||||||
// update current loop tensor to 0 per iterator
|
auto tensor = (*inputs_params)[0];
|
||||||
auto cur_loop_tensor = (*inputs_params)[0];
|
MS_EXCEPTION_IF_NULL(tensor);
|
||||||
MS_EXCEPTION_IF_NULL(cur_loop_tensor);
|
auto *val = static_cast<int32_t *>(tensor->data_c());
|
||||||
auto *cur_val = static_cast<int32_t *>(cur_loop_tensor->data_c());
|
MS_EXCEPTION_IF_NULL(val);
|
||||||
MS_EXCEPTION_IF_NULL(cur_val);
|
*val = 0;
|
||||||
*cur_val = 0;
|
tensor->set_dirty(true);
|
||||||
cur_loop_tensor->set_dirty(true);
|
|
||||||
// set loop_count to zero
|
// set loop_count to zero
|
||||||
MS_EXCEPTION_IF_NULL(inputs);
|
MS_EXCEPTION_IF_NULL(inputs);
|
||||||
inputs->push_back(cur_loop_tensor);
|
inputs->push_back(tensor);
|
||||||
|
|
||||||
// update next loop tensor to 0 per iterator
|
auto epoch_tensor = (*inputs_params)[1];
|
||||||
auto next_loop_tensor = (*inputs_params)[1];
|
|
||||||
MS_EXCEPTION_IF_NULL(next_loop_tensor);
|
|
||||||
auto *next_val = static_cast<int32_t *>(next_loop_tensor->data_c());
|
|
||||||
MS_EXCEPTION_IF_NULL(next_val);
|
|
||||||
*next_val = 0;
|
|
||||||
next_loop_tensor->set_dirty(true);
|
|
||||||
// set loop_count to zero
|
|
||||||
MS_EXCEPTION_IF_NULL(inputs);
|
|
||||||
inputs->push_back(next_loop_tensor);
|
|
||||||
|
|
||||||
auto epoch_tensor = (*inputs_params)[2];
|
|
||||||
MS_EXCEPTION_IF_NULL(epoch_tensor);
|
MS_EXCEPTION_IF_NULL(epoch_tensor);
|
||||||
auto *epoch_val = static_cast<int32_t *>(epoch_tensor->data_c());
|
auto *epoch_val = static_cast<int32_t *>(epoch_tensor->data_c());
|
||||||
MS_EXCEPTION_IF_NULL(epoch_val);
|
MS_EXCEPTION_IF_NULL(epoch_val);
|
||||||
|
@ -942,7 +930,7 @@ bool TensorNeedSync(const AnfNodePtr ¶meter, const tensor::TensorPtr &tensor
|
||||||
void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
|
void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
|
||||||
const std::vector<tensor::TensorPtr> &inputs_const) const {
|
const std::vector<tensor::TensorPtr> &inputs_const) const {
|
||||||
std::vector<tensor::TensorPtr> inputs(inputs_const);
|
std::vector<tensor::TensorPtr> inputs(inputs_const);
|
||||||
size_t input_ctrl_size = 3;
|
size_t input_ctrl_size = 2;
|
||||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||||
if (kernel_graph->input_ctrl_tensors()) {
|
if (kernel_graph->input_ctrl_tensors()) {
|
||||||
input_ctrl_size = LoadCtrlInputTensor(kernel_graph, &inputs);
|
input_ctrl_size = LoadCtrlInputTensor(kernel_graph, &inputs);
|
||||||
|
@ -952,7 +940,7 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap
|
||||||
auto params = AnfAlgo::GetAllOutput(input_node);
|
auto params = AnfAlgo::GetAllOutput(input_node);
|
||||||
std::copy(params.begin(), params.end(), std::back_inserter(input_nodes));
|
std::copy(params.begin(), params.end(), std::back_inserter(input_nodes));
|
||||||
}
|
}
|
||||||
if ((inputs.size() + input_ctrl_size) - 3 != input_nodes.size()) {
|
if ((inputs.size() + input_ctrl_size) - 2 != input_nodes.size()) {
|
||||||
MS_LOG(EXCEPTION) << "Tensor input:" << inputs.size() << " is not equal graph inputs:" << input_nodes.size()
|
MS_LOG(EXCEPTION) << "Tensor input:" << inputs.size() << " is not equal graph inputs:" << input_nodes.size()
|
||||||
<< ", input_ctrl_size:" << input_ctrl_size;
|
<< ", input_ctrl_size:" << input_ctrl_size;
|
||||||
}
|
}
|
||||||
|
|
|
@ -42,9 +42,6 @@ void AscendStreamAssign::AssignStream(const NotNull<KernelGraphPtr> &graph_ptr)
|
||||||
InsertStreamActive(graph_ptr);
|
InsertStreamActive(graph_ptr);
|
||||||
InsertEventForHcomParallel(graph_ptr);
|
InsertEventForHcomParallel(graph_ptr);
|
||||||
InsertEventForIndependentParallel(graph_ptr);
|
InsertEventForIndependentParallel(graph_ptr);
|
||||||
GetIndependentMaxTarget(graph_ptr);
|
|
||||||
InsertCtrlForIndependentParallel(graph_ptr);
|
|
||||||
|
|
||||||
GetNeedActiveStreams(graph_ptr);
|
GetNeedActiveStreams(graph_ptr);
|
||||||
graph_ptr->PrintGraphExecuteOrder();
|
graph_ptr->PrintGraphExecuteOrder();
|
||||||
CheckResourceAssign(graph_ptr);
|
CheckResourceAssign(graph_ptr);
|
||||||
|
@ -69,7 +66,7 @@ void AscendStreamAssign::ReorderIndependentOrders(const NotNull<KernelGraphPtr>
|
||||||
for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
|
for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
|
||||||
auto cur_cnode_ptr = cnode_ptr_list[i];
|
auto cur_cnode_ptr = cnode_ptr_list[i];
|
||||||
MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
|
MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
|
||||||
if (AnfAlgo::IsIndependentNode(cur_cnode_ptr)) {
|
if (IsIndependentNode(cur_cnode_ptr)) {
|
||||||
independents.emplace_back(cur_cnode_ptr);
|
independents.emplace_back(cur_cnode_ptr);
|
||||||
} else {
|
} else {
|
||||||
others.emplace_back(cur_cnode_ptr);
|
others.emplace_back(cur_cnode_ptr);
|
||||||
|
@ -136,7 +133,7 @@ void AscendStreamAssign::AssignAllNodesStream(const NotNull<KernelGraphPtr> &gra
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (AnfAlgo::IsIndependentNode(cur_cnode_ptr)) {
|
if (IsIndependentNode(cur_cnode_ptr)) {
|
||||||
exit_independent = true;
|
exit_independent = true;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
@ -168,7 +165,7 @@ void AscendStreamAssign::AssignAllNodesStream(const NotNull<KernelGraphPtr> &gra
|
||||||
if (AnfAlgo::GetStreamId(cur_cnode_ptr) != kInvalidStreamId) {
|
if (AnfAlgo::GetStreamId(cur_cnode_ptr) != kInvalidStreamId) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if (AnfAlgo::IsIndependentNode(cur_cnode_ptr)) {
|
if (IsIndependentNode(cur_cnode_ptr)) {
|
||||||
AssignIndependentStreamId(cur_cnode_ptr);
|
AssignIndependentStreamId(cur_cnode_ptr);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -245,6 +242,33 @@ void AscendStreamAssign::AssignIndependentStreamId(const CNodePtr &cur_cnode_ptr
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool AscendStreamAssign::IsIndependentNode(const CNodePtr &node_ptr) {
|
||||||
|
MS_EXCEPTION_IF_NULL(node_ptr);
|
||||||
|
if (AnfAlgo::GetKernelType(node_ptr) != AICPU_KERNEL) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (AnfAlgo::GetCNodeName(node_ptr) == kGetNextOpName) {
|
||||||
|
MS_LOG(INFO) << "GetNext should not be independent node";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t input_nums = AnfAlgo::GetInputTensorNum(node_ptr);
|
||||||
|
if (input_nums == 0) {
|
||||||
|
MS_LOG(INFO) << "Node " << node_ptr->fullname_with_scope() << " is independent, as inputs nums is zero";
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto inputs = node_ptr->inputs();
|
||||||
|
for (size_t i = 1; i < inputs.size(); i++) {
|
||||||
|
if (!inputs[i]->isa<ValueNode>()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
MS_LOG(INFO) << "Node " << node_ptr->fullname_with_scope() << " is independent, as inputs is all value node";
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
// section 3:
|
// section 3:
|
||||||
void AscendStreamAssign::UpdateAtomicAddrCleanStreamId(const NotNull<KernelGraphPtr> &graph_ptr) {
|
void AscendStreamAssign::UpdateAtomicAddrCleanStreamId(const NotNull<KernelGraphPtr> &graph_ptr) {
|
||||||
MS_LOG(INFO) << "Start";
|
MS_LOG(INFO) << "Start";
|
||||||
|
@ -269,11 +293,13 @@ void AscendStreamAssign::InsertStreamActive(const NotNull<KernelGraphPtr> &graph
|
||||||
CNodePtr pre_cnode_ptr = nullptr;
|
CNodePtr pre_cnode_ptr = nullptr;
|
||||||
uint32_t pre_stream_id = UINT32_MAX;
|
uint32_t pre_stream_id = UINT32_MAX;
|
||||||
|
|
||||||
|
bool independent_flag = !(independent_stream_map_.empty());
|
||||||
|
bool hcom_flag = !(hcom_stream_map_.empty());
|
||||||
auto cnode_ptr_list = graph_ptr->execution_order();
|
auto cnode_ptr_list = graph_ptr->execution_order();
|
||||||
for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
|
for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
|
||||||
cur_cnode_ptr = cnode_ptr_list[i];
|
cur_cnode_ptr = cnode_ptr_list[i];
|
||||||
MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
|
MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
|
||||||
if (AnfAlgo::IsIndependentNode(cur_cnode_ptr)) {
|
if (IsIndependentNode(cur_cnode_ptr)) {
|
||||||
update_cnode_list.emplace_back(cur_cnode_ptr);
|
update_cnode_list.emplace_back(cur_cnode_ptr);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
@ -296,7 +322,7 @@ void AscendStreamAssign::InsertStreamActive(const NotNull<KernelGraphPtr> &graph
|
||||||
update_cnode_list.emplace_back(active_ptr);
|
update_cnode_list.emplace_back(active_ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kStreamSwitchOpName) {
|
if ((independent_flag || hcom_flag) && (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kStreamSwitchOpName)) {
|
||||||
MS_LOG(INFO) << "Insert StreamActive op after FP StreamSwitch for stream parallel";
|
MS_LOG(INFO) << "Insert StreamActive op after FP StreamSwitch for stream parallel";
|
||||||
UpdateStreamSwitch(graph_ptr, cur_cnode_ptr, &update_cnode_list);
|
UpdateStreamSwitch(graph_ptr, cur_cnode_ptr, &update_cnode_list);
|
||||||
} else {
|
} else {
|
||||||
|
@ -320,10 +346,8 @@ void AscendStreamAssign::GetProcessedStream(const NotNull<KernelGraphPtr> &graph
|
||||||
uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr);
|
uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr);
|
||||||
|
|
||||||
if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kStreamSwitchOpName) {
|
if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kStreamSwitchOpName) {
|
||||||
if (AnfAlgo::HasNodeAttr(kAttrTrueBranchStream, cur_cnode_ptr)) {
|
|
||||||
auto true_stream_id = AnfAlgo::GetNodeAttr<uint32_t>(cur_cnode_ptr, kAttrTrueBranchStream);
|
auto true_stream_id = AnfAlgo::GetNodeAttr<uint32_t>(cur_cnode_ptr, kAttrTrueBranchStream);
|
||||||
processed_streams_.emplace(true_stream_id);
|
processed_streams_.emplace(true_stream_id);
|
||||||
}
|
|
||||||
|
|
||||||
if (!AnfAlgo::HasNodeAttr(kStreamNeedActivedFirst, cur_cnode_ptr)) {
|
if (!AnfAlgo::HasNodeAttr(kStreamNeedActivedFirst, cur_cnode_ptr)) {
|
||||||
continue;
|
continue;
|
||||||
|
@ -341,78 +365,46 @@ void AscendStreamAssign::GetProcessedStream(const NotNull<KernelGraphPtr> &graph
|
||||||
|
|
||||||
void AscendStreamAssign::UpdateStreamSwitch(const NotNull<KernelGraphPtr> &graph_ptr, const CNodePtr &switch_ptr,
|
void AscendStreamAssign::UpdateStreamSwitch(const NotNull<KernelGraphPtr> &graph_ptr, const CNodePtr &switch_ptr,
|
||||||
vector<CNodePtr> *orders) {
|
vector<CNodePtr> *orders) {
|
||||||
if (!AnfAlgo::HasNodeAttr(kStreamNeedActivedFirst, switch_ptr)) {
|
|
||||||
orders->emplace_back(switch_ptr);
|
orders->emplace_back(switch_ptr);
|
||||||
|
if (!AnfAlgo::HasNodeAttr(kStreamNeedActivedFirst, switch_ptr)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto need_active = AnfAlgo::GetNodeAttr<bool>(switch_ptr, kStreamNeedActivedFirst);
|
auto need_active = AnfAlgo::GetNodeAttr<bool>(switch_ptr, kStreamNeedActivedFirst);
|
||||||
if (!need_active) {
|
if (!need_active) {
|
||||||
orders->emplace_back(switch_ptr);
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!AnfAlgo::HasNodeAttr(kAttrStreamSwitchKind, switch_ptr)) {
|
MS_EXCEPTION_IF_NULL(switch_ptr);
|
||||||
orders->emplace_back(switch_ptr);
|
auto true_stream_id = AnfAlgo::GetNodeAttr<uint32_t>(switch_ptr, kAttrTrueBranchStream);
|
||||||
return;
|
MS_LOG(INFO) << "Streamswtich stream id:" << AnfAlgo::GetStreamId(switch_ptr)
|
||||||
}
|
<< "; active stream id:" << true_stream_id;
|
||||||
auto kind = AnfAlgo::GetNodeAttr<uint32_t>(switch_ptr, kAttrStreamSwitchKind);
|
|
||||||
if (kind == kEosStreamSwitch || kind == kGetNextStreamSwitch) {
|
|
||||||
orders->emplace_back(switch_ptr);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (kind == kIndependentStreamSwitch) {
|
CNodePtr active_ptr = KernelAdjust::GetInstance().CreateStreamActiveOp(graph_ptr);
|
||||||
bool independent_empty = independent_stream_map_.empty();
|
AnfAlgo::SetStreamId(true_stream_id, active_ptr.get());
|
||||||
// if indepdent empty: delete independent streamswitch
|
vector<uint32_t> active_ids;
|
||||||
if (!independent_empty) {
|
// active indepdent stream
|
||||||
for (const auto &item : independent_stream_map_) {
|
for (const auto &item : independent_stream_map_) {
|
||||||
// first independetn stream id is minimum and order by std map;
|
active_ids.emplace_back(item.first);
|
||||||
auto first_independent_stream = item.first;
|
|
||||||
AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue<uint32_t>(first_independent_stream), switch_ptr);
|
|
||||||
orders->emplace_back(switch_ptr);
|
|
||||||
break;
|
|
||||||
}
|
}
|
||||||
} else {
|
// active hcom stream
|
||||||
MS_LOG(ERROR) << "independent stream switch exit, but independent stream is empty";
|
for (const auto &item : hcom_stream_map_) {
|
||||||
|
active_ids.emplace_back(item.first);
|
||||||
}
|
}
|
||||||
|
AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue<std::vector<uint32_t>>(active_ids), active_ptr);
|
||||||
|
|
||||||
// update processed stream
|
// update processed stream
|
||||||
independent_stream_activated_ = true;
|
independent_stream_activated_ = true;
|
||||||
for (const auto &item : independent_stream_map_) {
|
for (const auto &item : independent_stream_map_) {
|
||||||
processed_streams_.emplace(item.first);
|
processed_streams_.emplace(item.first);
|
||||||
}
|
}
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (kind == kFpBpStreamSwitch) {
|
|
||||||
bool hcom_empty = hcom_stream_map_.empty();
|
|
||||||
if (hcom_empty) {
|
|
||||||
orders->emplace_back(switch_ptr);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
if (!AnfAlgo::HasNodeAttr(kAttrTrueBranchStream, switch_ptr)) {
|
|
||||||
orders->emplace_back(switch_ptr);
|
|
||||||
MS_LOG(WARNING) << "FpBp StreamSwitch has no true branch attr";
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
auto true_stream_id = AnfAlgo::GetNodeAttr<uint32_t>(switch_ptr, kAttrTrueBranchStream);
|
|
||||||
MS_LOG(INFO) << "Streamswtich stream id:" << AnfAlgo::GetStreamId(switch_ptr)
|
|
||||||
<< "; active stream id:" << true_stream_id;
|
|
||||||
CNodePtr active_ptr = KernelAdjust::GetInstance().CreateStreamActiveOp(graph_ptr);
|
|
||||||
AnfAlgo::SetStreamId(true_stream_id, active_ptr.get());
|
|
||||||
vector<uint32_t> active_ids;
|
|
||||||
// active hcom stream
|
|
||||||
for (const auto &item : hcom_stream_map_) {
|
|
||||||
active_ids.emplace_back(item.first);
|
|
||||||
}
|
|
||||||
AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue<std::vector<uint32_t>>(active_ids), active_ptr);
|
|
||||||
hcom_stream_activated_ = true;
|
hcom_stream_activated_ = true;
|
||||||
for (const auto &item : hcom_stream_map_) {
|
for (const auto &item : hcom_stream_map_) {
|
||||||
processed_streams_.emplace(item.first);
|
processed_streams_.emplace(item.first);
|
||||||
}
|
}
|
||||||
orders->emplace_back(switch_ptr);
|
|
||||||
orders->emplace_back(active_ptr);
|
orders->emplace_back(active_ptr);
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bool AscendStreamAssign::IsProcessedStream(uint32_t stream_id) {
|
bool AscendStreamAssign::IsProcessedStream(uint32_t stream_id) {
|
||||||
|
@ -640,7 +632,7 @@ void AscendStreamAssign::InsertEventForIndependentParallel(const NotNull<KernelG
|
||||||
auto it = cnodes.begin();
|
auto it = cnodes.begin();
|
||||||
while (it != cnodes.end()) {
|
while (it != cnodes.end()) {
|
||||||
MS_EXCEPTION_IF_NULL(*it);
|
MS_EXCEPTION_IF_NULL(*it);
|
||||||
if (AnfAlgo::IsIndependentNode(*it)) {
|
if (IsIndependentNode(*it)) {
|
||||||
MS_LOG(INFO) << "Deal independent op[" << (*it)->DebugString() << "]";
|
MS_LOG(INFO) << "Deal independent op[" << (*it)->DebugString() << "]";
|
||||||
CNodePtr send_cnode_ptr = CreateSendApplyKernel(graph_ptr, cur_event_id, AnfAlgo::GetStreamId(*it));
|
CNodePtr send_cnode_ptr = CreateSendApplyKernel(graph_ptr, cur_event_id, AnfAlgo::GetStreamId(*it));
|
||||||
it = cnodes.insert(it + 1, send_cnode_ptr);
|
it = cnodes.insert(it + 1, send_cnode_ptr);
|
||||||
|
@ -668,129 +660,6 @@ void AscendStreamAssign::InsertEventForIndependentParallel(const NotNull<KernelG
|
||||||
MS_LOG(INFO) << "End";
|
MS_LOG(INFO) << "End";
|
||||||
}
|
}
|
||||||
|
|
||||||
void AscendStreamAssign::GetIndependentMaxTarget(const NotNull<KernelGraphPtr> &graph_ptr) {
|
|
||||||
MS_LOG(INFO) << "Start";
|
|
||||||
auto cnode_ptr_list = graph_ptr->execution_order();
|
|
||||||
for (size_t i = 0; i < cnode_ptr_list.size(); i++) {
|
|
||||||
auto cur_node = cnode_ptr_list[i];
|
|
||||||
auto key = cur_node.get();
|
|
||||||
if (!AnfAlgo::IsIndependentNode(cur_node)) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool flag = false;
|
|
||||||
for (size_t j = cnode_ptr_list.size() - 1; j > i; j--) {
|
|
||||||
auto target_node = cnode_ptr_list[j];
|
|
||||||
auto inputs = target_node->inputs();
|
|
||||||
for (size_t m = 1; m < inputs.size(); m++) {
|
|
||||||
auto input = inputs[m];
|
|
||||||
if (opt::IsNopNode(input)) {
|
|
||||||
CNodePtr cnode = input->cast<CNodePtr>();
|
|
||||||
auto new_inputs = cnode->inputs();
|
|
||||||
for (size_t k = 1; k < new_inputs.size(); k++) {
|
|
||||||
auto new_real_input = AnfAlgo::VisitKernel(new_inputs[k], 0);
|
|
||||||
if (key == new_real_input.first.get()) {
|
|
||||||
MS_LOG(INFO) << "Nop node find max target op:" << AnfAlgo::GetCNodeName(cur_node);
|
|
||||||
independent_targets_.emplace(target_node.get());
|
|
||||||
flag = true;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
auto real_input = AnfAlgo::VisitKernel(input, 0);
|
|
||||||
if (key == real_input.first.get()) {
|
|
||||||
MS_LOG(INFO) << "Find max target op:" << AnfAlgo::GetCNodeName(cur_node);
|
|
||||||
independent_targets_.emplace(target_node.get());
|
|
||||||
flag = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (flag) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
MS_LOG(INFO) << "End";
|
|
||||||
}
|
|
||||||
|
|
||||||
uint32_t AscendStreamAssign::GetIndexByKey(const NotNull<KernelGraphPtr> &graph_ptr, const CNodeKey &key) {
|
|
||||||
auto &exe_orders = graph_ptr->execution_order();
|
|
||||||
for (uint32_t i = 0; i < exe_orders.size(); i++) {
|
|
||||||
CNodeKey node_key = exe_orders[i].get();
|
|
||||||
if (node_key == key) {
|
|
||||||
return i;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return UINT32_MAX;
|
|
||||||
}
|
|
||||||
|
|
||||||
uint32_t AscendStreamAssign::GetMaxIndexTarget(const NotNull<KernelGraphPtr> &graph_ptr) {
|
|
||||||
if (independent_targets_.empty()) {
|
|
||||||
return UINT32_MAX;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::set<uint32_t> indexs;
|
|
||||||
for (const auto &key : independent_targets_) {
|
|
||||||
auto index = GetIndexByKey(graph_ptr, key);
|
|
||||||
if (index == UINT32_MAX) {
|
|
||||||
MS_LOG(EXCEPTION) << "graph has no correspond key";
|
|
||||||
}
|
|
||||||
indexs.emplace(index);
|
|
||||||
}
|
|
||||||
|
|
||||||
return *(std::max_element(indexs.begin(), indexs.end()));
|
|
||||||
}
|
|
||||||
|
|
||||||
uint32_t AscendStreamAssign::GetIndependentStreamSwitchStreamId(const NotNull<KernelGraphPtr> &graph_ptr) {
|
|
||||||
auto &exe_orders = graph_ptr->execution_order();
|
|
||||||
for (const auto &item : exe_orders) {
|
|
||||||
if (AnfAlgo::GetCNodeName(item) == kStreamSwitchOpName) {
|
|
||||||
if (!AnfAlgo::HasNodeAttr(kAttrStreamSwitchKind, item)) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
auto kind = AnfAlgo::GetNodeAttr<uint32_t>(item, kAttrStreamSwitchKind);
|
|
||||||
if (kind == kIndependentStreamSwitch) {
|
|
||||||
return AnfAlgo::GetStreamId(item);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return kInvalidStreamId;
|
|
||||||
}
|
|
||||||
|
|
||||||
void AscendStreamAssign::InsertCtrlForIndependentParallel(const NotNull<KernelGraphPtr> &graph_ptr) {
|
|
||||||
if (independent_targets_.empty()) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
uint32_t independent_switch_stream = GetIndependentStreamSwitchStreamId(graph_ptr);
|
|
||||||
if (independent_switch_stream == kInvalidStreamId) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto max_index = GetMaxIndexTarget(graph_ptr);
|
|
||||||
auto &exe_orders = graph_ptr->execution_order();
|
|
||||||
if (max_index >= exe_orders.size()) {
|
|
||||||
MS_LOG(EXCEPTION) << "max target index:" << max_index << " is greater than graph orders size:" << exe_orders.size();
|
|
||||||
}
|
|
||||||
|
|
||||||
auto max_node_stream = AnfAlgo::GetStreamId(exe_orders[max_index]);
|
|
||||||
|
|
||||||
CNodePtr active_ptr = KernelAdjust::GetInstance().CreateStreamActiveOp(graph_ptr);
|
|
||||||
// 1.set stream id
|
|
||||||
AnfAlgo::SetStreamId(max_node_stream, active_ptr.get());
|
|
||||||
// 2.set active stream ids
|
|
||||||
std::vector<uint32_t> active_index_list{independent_switch_stream};
|
|
||||||
AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue<std::vector<uint32_t>>(active_index_list), active_ptr);
|
|
||||||
|
|
||||||
std::vector<CNodePtr> update_cnode_list;
|
|
||||||
std::copy(exe_orders.begin(), exe_orders.begin() + max_index + 1, std::back_inserter(update_cnode_list));
|
|
||||||
update_cnode_list.emplace_back(active_ptr);
|
|
||||||
std::copy(exe_orders.begin() + max_index + 1, exe_orders.end(), std::back_inserter(update_cnode_list));
|
|
||||||
graph_ptr->set_execution_order(update_cnode_list);
|
|
||||||
}
|
|
||||||
|
|
||||||
// section7
|
// section7
|
||||||
void AscendStreamAssign::GetNeedActiveStreams(const NotNull<KernelGraphPtr> &graph_ptr) {
|
void AscendStreamAssign::GetNeedActiveStreams(const NotNull<KernelGraphPtr> &graph_ptr) {
|
||||||
CNodePtr cur_cnode_ptr = nullptr;
|
CNodePtr cur_cnode_ptr = nullptr;
|
||||||
|
@ -1048,7 +917,6 @@ void AscendStreamAssign::Reset() {
|
||||||
stream_groups_.clear();
|
stream_groups_.clear();
|
||||||
stream_relations_.clear();
|
stream_relations_.clear();
|
||||||
event_map_.clear();
|
event_map_.clear();
|
||||||
independent_targets_.clear();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// section 10
|
// section 10
|
||||||
|
|
|
@ -39,7 +39,6 @@ using std::shared_ptr;
|
||||||
using std::unordered_map;
|
using std::unordered_map;
|
||||||
using std::unordered_set;
|
using std::unordered_set;
|
||||||
using std::vector;
|
using std::vector;
|
||||||
using CNodeKey = void *;
|
|
||||||
const uint32_t kInvalidStreamId = UINT32_MAX;
|
const uint32_t kInvalidStreamId = UINT32_MAX;
|
||||||
const uint32_t kInvalidEventId = UINT32_MAX;
|
const uint32_t kInvalidEventId = UINT32_MAX;
|
||||||
class AscendResourceMng {
|
class AscendResourceMng {
|
||||||
|
@ -109,6 +108,8 @@ class AscendStreamAssign {
|
||||||
void AssignStream(const NotNull<KernelGraphPtr> &graph_ptr);
|
void AssignStream(const NotNull<KernelGraphPtr> &graph_ptr);
|
||||||
void GetHcomStreams(std::vector<uint32_t> *streams);
|
void GetHcomStreams(std::vector<uint32_t> *streams);
|
||||||
void GetWaitStreams(vector<uint32_t> *wait_active_stream_list);
|
void GetWaitStreams(vector<uint32_t> *wait_active_stream_list);
|
||||||
|
CNodePtr CreateSendApplyKernel(const NotNull<KernelGraphPtr> &graph_ptr, uint32_t event_id, uint32_t stream_id);
|
||||||
|
CNodePtr CreateRecvApplyKernel(const NotNull<KernelGraphPtr> &graph_ptr, uint32_t event_id, uint32_t stream_id);
|
||||||
const std::vector<std::vector<uint32_t>> &get_stream_group() const { return stream_groups_; }
|
const std::vector<std::vector<uint32_t>> &get_stream_group() const { return stream_groups_; }
|
||||||
const std::map<CNodePtr, CNodePtr> &get_event_map() const { return event_map_; }
|
const std::map<CNodePtr, CNodePtr> &get_event_map() const { return event_map_; }
|
||||||
|
|
||||||
|
@ -116,8 +117,6 @@ class AscendStreamAssign {
|
||||||
AscendStreamAssign() = default;
|
AscendStreamAssign() = default;
|
||||||
~AscendStreamAssign() = default;
|
~AscendStreamAssign() = default;
|
||||||
void Reset();
|
void Reset();
|
||||||
CNodePtr CreateSendApplyKernel(const NotNull<KernelGraphPtr> &graph_ptr, uint32_t event_id, uint32_t stream_id);
|
|
||||||
CNodePtr CreateRecvApplyKernel(const NotNull<KernelGraphPtr> &graph_ptr, uint32_t event_id, uint32_t stream_id);
|
|
||||||
void CheckResourceAssign(const NotNull<KernelGraphPtr> &graph_ptr);
|
void CheckResourceAssign(const NotNull<KernelGraphPtr> &graph_ptr);
|
||||||
void CheckStreamAssign(const NotNull<KernelGraphPtr> &graph_ptr);
|
void CheckStreamAssign(const NotNull<KernelGraphPtr> &graph_ptr);
|
||||||
void CheckEventAssign(const NotNull<KernelGraphPtr> &graph_ptr);
|
void CheckEventAssign(const NotNull<KernelGraphPtr> &graph_ptr);
|
||||||
|
@ -131,7 +130,6 @@ class AscendStreamAssign {
|
||||||
void UpdateStreamSwitch(const NotNull<KernelGraphPtr> &graph_ptr, const CNodePtr &switch_ptr,
|
void UpdateStreamSwitch(const NotNull<KernelGraphPtr> &graph_ptr, const CNodePtr &switch_ptr,
|
||||||
vector<CNodePtr> *orders);
|
vector<CNodePtr> *orders);
|
||||||
void InsertEventForIndependentParallel(const NotNull<KernelGraphPtr> &graph_ptr);
|
void InsertEventForIndependentParallel(const NotNull<KernelGraphPtr> &graph_ptr);
|
||||||
void InsertCtrlForIndependentParallel(const NotNull<KernelGraphPtr> &graph_ptr);
|
|
||||||
void InsertEventForHcomParallel(const NotNull<KernelGraphPtr> &graph_ptr);
|
void InsertEventForHcomParallel(const NotNull<KernelGraphPtr> &graph_ptr);
|
||||||
void InsertEventCommonDependHcom(const NotNull<KernelGraphPtr> &graph_ptr);
|
void InsertEventCommonDependHcom(const NotNull<KernelGraphPtr> &graph_ptr);
|
||||||
void InsertEventHcomDependCommon(const NotNull<KernelGraphPtr> &graph_ptr);
|
void InsertEventHcomDependCommon(const NotNull<KernelGraphPtr> &graph_ptr);
|
||||||
|
@ -143,10 +141,6 @@ class AscendStreamAssign {
|
||||||
void GetProcessedStream(const NotNull<KernelGraphPtr> &graph_ptr);
|
void GetProcessedStream(const NotNull<KernelGraphPtr> &graph_ptr);
|
||||||
void GetNeedActiveStreams(const NotNull<KernelGraphPtr> &graph_ptr);
|
void GetNeedActiveStreams(const NotNull<KernelGraphPtr> &graph_ptr);
|
||||||
void ReorderIndependentOrders(const NotNull<KernelGraphPtr> &graph_ptr);
|
void ReorderIndependentOrders(const NotNull<KernelGraphPtr> &graph_ptr);
|
||||||
uint32_t GetMaxIndexTarget(const NotNull<KernelGraphPtr> &graph_ptr);
|
|
||||||
uint32_t GetIndexByKey(const NotNull<KernelGraphPtr> &graph_ptr, const CNodeKey &key);
|
|
||||||
uint32_t GetIndependentStreamSwitchStreamId(const NotNull<KernelGraphPtr> &graph_ptr);
|
|
||||||
void GetIndependentMaxTarget(const NotNull<KernelGraphPtr> &graph_ptr);
|
|
||||||
|
|
||||||
bool IsTaskSink();
|
bool IsTaskSink();
|
||||||
bool IsHcom(const CNodePtr &cur_cnode_ptr);
|
bool IsHcom(const CNodePtr &cur_cnode_ptr);
|
||||||
|
@ -177,7 +171,6 @@ class AscendStreamAssign {
|
||||||
std::map<uint32_t, uint32_t> common_stream_map_{};
|
std::map<uint32_t, uint32_t> common_stream_map_{};
|
||||||
std::set<uint32_t> processed_streams_{};
|
std::set<uint32_t> processed_streams_{};
|
||||||
std::vector<uint32_t> need_first_active_streams_{};
|
std::vector<uint32_t> need_first_active_streams_{};
|
||||||
std::set<CNodeKey> independent_targets_;
|
|
||||||
|
|
||||||
// attr for memory copy reuse
|
// attr for memory copy reuse
|
||||||
std::map<uint32_t, std::vector<uint32_t>> stream_relations_{};
|
std::map<uint32_t, std::vector<uint32_t>> stream_relations_{};
|
||||||
|
|
|
@ -34,8 +34,8 @@ static constexpr uint32_t kTupleTaskId = 0;
|
||||||
static constexpr uint32_t kTupleStreamId = 1;
|
static constexpr uint32_t kTupleStreamId = 1;
|
||||||
static constexpr uint32_t kTupleArgs = 2;
|
static constexpr uint32_t kTupleArgs = 2;
|
||||||
static constexpr uint32_t kCurrentStepTensorIndex = 0;
|
static constexpr uint32_t kCurrentStepTensorIndex = 0;
|
||||||
static constexpr uint32_t kCurrentEpochTensorIndex = 2;
|
static constexpr uint32_t kCurrentEpochTensorIndex = 1;
|
||||||
static constexpr uint32_t kStepsPerEpochTensorIndex = 3;
|
static constexpr uint32_t kStepsPerEpochTensorIndex = 2;
|
||||||
static constexpr uint64_t kOpDebugShape = 2048;
|
static constexpr uint64_t kOpDebugShape = 2048;
|
||||||
static constexpr uint64_t kOpDebugHostMemSize = 2048;
|
static constexpr uint64_t kOpDebugHostMemSize = 2048;
|
||||||
static constexpr uint64_t kOpDebugDevMemSize = sizeof(void *);
|
static constexpr uint64_t kOpDebugDevMemSize = sizeof(void *);
|
||||||
|
|
|
@ -106,19 +106,6 @@ CNodePtr KernelAdjust::CreateRecvApplyKernel(const std::shared_ptr<session::Kern
|
||||||
return recv_node_ptr;
|
return recv_node_ptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool KernelAdjust::ExitIndependent(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) {
|
|
||||||
MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
|
|
||||||
const auto &exe_orders = kernel_graph_ptr->execution_order();
|
|
||||||
for (const auto &node : exe_orders) {
|
|
||||||
if (AnfAlgo::IsIndependentNode(node)) {
|
|
||||||
MS_LOG(INFO) << "graph exit independent node";
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) {
|
void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) {
|
||||||
device::ascend::AscendResourceMng &resource_manager = device::ascend::AscendResourceMng::GetInstance();
|
device::ascend::AscendResourceMng &resource_manager = device::ascend::AscendResourceMng::GetInstance();
|
||||||
resource_manager.ResetResource();
|
resource_manager.ResetResource();
|
||||||
|
@ -133,10 +120,10 @@ void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph>
|
||||||
|
|
||||||
std::vector<AnfNodePtr> *mute_inputs = kernel_graph_ptr->MutableInputs();
|
std::vector<AnfNodePtr> *mute_inputs = kernel_graph_ptr->MutableInputs();
|
||||||
MS_EXCEPTION_IF_NULL(mute_inputs);
|
MS_EXCEPTION_IF_NULL(mute_inputs);
|
||||||
mute_inputs->push_back(switch_loop_input[kCurLoopCountParamName]);
|
mute_inputs->push_back(switch_loop_input[kLoopCountParamName]);
|
||||||
mute_inputs->push_back(switch_loop_input[kNextLoopCountParamName]);
|
|
||||||
mute_inputs->push_back(switch_loop_input[kEpochParamName]);
|
mute_inputs->push_back(switch_loop_input[kEpochParamName]);
|
||||||
mute_inputs->push_back(switch_loop_input[kIterLoopParamName]);
|
mute_inputs->push_back(switch_loop_input[kIterLoopParamName]);
|
||||||
|
mute_inputs->push_back(switch_loop_input[kZeroParamName]);
|
||||||
mute_inputs->push_back(switch_loop_input[kOneParamName]);
|
mute_inputs->push_back(switch_loop_input[kOneParamName]);
|
||||||
for (const auto &input : kernel_graph_ptr->inputs()) {
|
for (const auto &input : kernel_graph_ptr->inputs()) {
|
||||||
MS_EXCEPTION_IF_NULL(input);
|
MS_EXCEPTION_IF_NULL(input);
|
||||||
|
@ -161,7 +148,7 @@ void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph>
|
||||||
|
|
||||||
// getnext loop process
|
// getnext loop process
|
||||||
// getnext loop stream switch op
|
// getnext loop stream switch op
|
||||||
CNodePtr getnext_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input, kGetNextStreamSwitch);
|
CNodePtr getnext_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input);
|
||||||
MS_EXCEPTION_IF_NULL(getnext_switch_app);
|
MS_EXCEPTION_IF_NULL(getnext_switch_app);
|
||||||
uint32_t getnext_switch_stream_id = resource_manager.ApplyNewStream();
|
uint32_t getnext_switch_stream_id = resource_manager.ApplyNewStream();
|
||||||
AnfAlgo::SetStreamId(getnext_switch_stream_id, getnext_switch_app.get());
|
AnfAlgo::SetStreamId(getnext_switch_stream_id, getnext_switch_app.get());
|
||||||
|
@ -181,9 +168,7 @@ void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph>
|
||||||
}
|
}
|
||||||
|
|
||||||
// update getnext loop stream switch true_branch_stream attr
|
// update getnext loop stream switch true_branch_stream attr
|
||||||
AnfAlgo::SetNodeAttr(kStreamNeedActivedFirst, MakeValue<bool>(true), getnext_switch_app);
|
|
||||||
AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue<uint32_t>(getnext_stream_id), getnext_switch_app);
|
AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue<uint32_t>(getnext_stream_id), getnext_switch_app);
|
||||||
AnfAlgo::SetNodeAttr(kAttrStreamSwitchKind, MakeValue<uint32_t>(kGetNextStreamSwitch), getnext_switch_app);
|
|
||||||
|
|
||||||
// getnext loop fpbp start send
|
// getnext loop fpbp start send
|
||||||
uint32_t fpbp_start_event_id = resource_manager.ApplyNewEvent();
|
uint32_t fpbp_start_event_id = resource_manager.ApplyNewEvent();
|
||||||
|
@ -200,7 +185,7 @@ void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph>
|
||||||
|
|
||||||
// End Of Sequence loop process
|
// End Of Sequence loop process
|
||||||
// eos loop stream switch
|
// eos loop stream switch
|
||||||
CNodePtr eos_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input, kEosStreamSwitch);
|
CNodePtr eos_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input);
|
||||||
MS_EXCEPTION_IF_NULL(eos_switch_app);
|
MS_EXCEPTION_IF_NULL(eos_switch_app);
|
||||||
uint32_t eos_switch_stream_id = resource_manager.ApplyNewStream();
|
uint32_t eos_switch_stream_id = resource_manager.ApplyNewStream();
|
||||||
AnfAlgo::SetStreamId(eos_switch_stream_id, eos_switch_app.get());
|
AnfAlgo::SetStreamId(eos_switch_stream_id, eos_switch_app.get());
|
||||||
|
@ -215,7 +200,6 @@ void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph>
|
||||||
|
|
||||||
// update eos loop stream switch true_branch_stream attr
|
// update eos loop stream switch true_branch_stream attr
|
||||||
AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue<uint32_t>(eos_stream_id), eos_switch_app);
|
AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue<uint32_t>(eos_stream_id), eos_switch_app);
|
||||||
AnfAlgo::SetNodeAttr(kAttrStreamSwitchKind, MakeValue<uint32_t>(kEosStreamSwitch), eos_switch_app);
|
|
||||||
|
|
||||||
// EndOfSequence op
|
// EndOfSequence op
|
||||||
CNodePtr end_of_sequence_op = CreateEndOfSequenceOP(kernel_graph_ptr, getnext_cnode);
|
CNodePtr end_of_sequence_op = CreateEndOfSequenceOP(kernel_graph_ptr, getnext_cnode);
|
||||||
|
@ -233,27 +217,13 @@ void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph>
|
||||||
fpbp_active_streams.push_back(eos_switch_stream_id);
|
fpbp_active_streams.push_back(eos_switch_stream_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool exit_independent = ExitIndependent(kernel_graph_ptr);
|
|
||||||
if (exit_independent) {
|
|
||||||
// Independet parallel
|
|
||||||
CNodePtr independent_switch_app =
|
|
||||||
CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input, kIndependentStreamSwitch);
|
|
||||||
MS_EXCEPTION_IF_NULL(independent_switch_app);
|
|
||||||
uint32_t independent_switch_stream_id = resource_manager.ApplyNewStream();
|
|
||||||
AnfAlgo::SetStreamId(independent_switch_stream_id, independent_switch_app.get());
|
|
||||||
AnfAlgo::SetNodeAttr(kStreamNeedActivedFirst, MakeValue<bool>(true), independent_switch_app);
|
|
||||||
AnfAlgo::SetNodeAttr(kAttrStreamSwitchKind, MakeValue<uint32_t>(kIndependentStreamSwitch), independent_switch_app);
|
|
||||||
exec_order.push_back(independent_switch_app);
|
|
||||||
}
|
|
||||||
|
|
||||||
// fpbp loop process
|
// fpbp loop process
|
||||||
// fpbp loop stream switch
|
// fpbp loop stream switch
|
||||||
CNodePtr fpbp_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input, kFpBpStreamSwitch);
|
CNodePtr fpbp_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input);
|
||||||
MS_EXCEPTION_IF_NULL(fpbp_switch_app);
|
MS_EXCEPTION_IF_NULL(fpbp_switch_app);
|
||||||
uint32_t fpbp_switch_stream_id = resource_manager.ApplyNewStream();
|
uint32_t fpbp_switch_stream_id = resource_manager.ApplyNewStream();
|
||||||
AnfAlgo::SetStreamId(fpbp_switch_stream_id, fpbp_switch_app.get());
|
AnfAlgo::SetStreamId(fpbp_switch_stream_id, fpbp_switch_app.get());
|
||||||
AnfAlgo::SetNodeAttr(kStreamNeedActivedFirst, MakeValue<bool>(true), fpbp_switch_app);
|
AnfAlgo::SetNodeAttr(kStreamNeedActivedFirst, MakeValue<bool>(true), fpbp_switch_app);
|
||||||
|
|
||||||
exec_order.push_back(fpbp_switch_app);
|
exec_order.push_back(fpbp_switch_app);
|
||||||
|
|
||||||
// fpbp loop fpbp start recv
|
// fpbp loop fpbp start recv
|
||||||
|
@ -264,9 +234,9 @@ void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph>
|
||||||
|
|
||||||
// update fpbp loop stream switch true_branch_stream attr
|
// update fpbp loop stream switch true_branch_stream attr
|
||||||
AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue<uint32_t>(fpbp_stream_id), fpbp_switch_app);
|
AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue<uint32_t>(fpbp_stream_id), fpbp_switch_app);
|
||||||
AnfAlgo::SetNodeAttr(kAttrStreamSwitchKind, MakeValue<uint32_t>(kFpBpStreamSwitch), fpbp_switch_app);
|
|
||||||
// next loop AssignAdd
|
// fpbp loop AssignAdd
|
||||||
CNodePtr assign_add_one = CreateStreamAssignAddnOP(kernel_graph_ptr, switch_loop_input, false);
|
CNodePtr assign_add_one = CreateStreamAssignAddnOP(kernel_graph_ptr, switch_loop_input);
|
||||||
MS_EXCEPTION_IF_NULL(assign_add_one);
|
MS_EXCEPTION_IF_NULL(assign_add_one);
|
||||||
AnfAlgo::SetStreamId(fpbp_stream_id, assign_add_one.get());
|
AnfAlgo::SetStreamId(fpbp_stream_id, assign_add_one.get());
|
||||||
exec_order.push_back(assign_add_one);
|
exec_order.push_back(assign_add_one);
|
||||||
|
@ -304,11 +274,6 @@ void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph>
|
||||||
// fpbp loop other ops
|
// fpbp loop other ops
|
||||||
(void)std::copy(other_list.begin(), other_list.end(), std::back_inserter(exec_order));
|
(void)std::copy(other_list.begin(), other_list.end(), std::back_inserter(exec_order));
|
||||||
|
|
||||||
// current assign add op
|
|
||||||
CNodePtr cur_assign_add = CreateStreamAssignAddnOP(kernel_graph_ptr, switch_loop_input, true);
|
|
||||||
MS_EXCEPTION_IF_NULL(cur_assign_add);
|
|
||||||
exec_order.push_back(cur_assign_add);
|
|
||||||
|
|
||||||
// stream active to activate fpbp loop and eos loop
|
// stream active to activate fpbp loop and eos loop
|
||||||
CNodePtr fpbp_active_app = CreateStreamActiveOp(kernel_graph_ptr);
|
CNodePtr fpbp_active_app = CreateStreamActiveOp(kernel_graph_ptr);
|
||||||
MS_EXCEPTION_IF_NULL(fpbp_active_app);
|
MS_EXCEPTION_IF_NULL(fpbp_active_app);
|
||||||
|
@ -331,19 +296,13 @@ void KernelAdjust::CreateSwitchOpParameters(const std::shared_ptr<session::Kerne
|
||||||
MS_LOG(EXCEPTION) << "create abstract before insert switch op failed!";
|
MS_LOG(EXCEPTION) << "create abstract before insert switch op failed!";
|
||||||
}
|
}
|
||||||
|
|
||||||
ParameterPtr cur_loop_count = std::make_shared<Parameter>(kernel_graph_ptr);
|
ParameterPtr loop_count = std::make_shared<Parameter>(kernel_graph_ptr);
|
||||||
MS_EXCEPTION_IF_NULL(cur_loop_count);
|
MS_EXCEPTION_IF_NULL(loop_count);
|
||||||
cur_loop_count->set_name(kCurLoopCountParamName);
|
loop_count->set_name(kLoopCountParamName);
|
||||||
cur_loop_count->set_abstract(paremeter_abstract_ptr);
|
loop_count->set_abstract(paremeter_abstract_ptr);
|
||||||
ParameterPtr loop_count_cur = kernel_graph_ptr->NewParameter(cur_loop_count);
|
ParameterPtr loop_count_new = kernel_graph_ptr->NewParameter(loop_count);
|
||||||
(*switch_loop_input)[kCurLoopCountParamName] = loop_count_cur;
|
|
||||||
|
|
||||||
ParameterPtr next_loop_count = std::make_shared<Parameter>(kernel_graph_ptr);
|
(*switch_loop_input)[kLoopCountParamName] = loop_count_new;
|
||||||
MS_EXCEPTION_IF_NULL(next_loop_count);
|
|
||||||
next_loop_count->set_name(kNextLoopCountParamName);
|
|
||||||
next_loop_count->set_abstract(paremeter_abstract_ptr);
|
|
||||||
ParameterPtr loop_count_next = kernel_graph_ptr->NewParameter(next_loop_count);
|
|
||||||
(*switch_loop_input)[kNextLoopCountParamName] = loop_count_next;
|
|
||||||
|
|
||||||
ParameterPtr iter_loop = std::make_shared<Parameter>(kernel_graph_ptr);
|
ParameterPtr iter_loop = std::make_shared<Parameter>(kernel_graph_ptr);
|
||||||
iter_loop->set_name(kIterLoopParamName);
|
iter_loop->set_name(kIterLoopParamName);
|
||||||
|
@ -351,6 +310,12 @@ void KernelAdjust::CreateSwitchOpParameters(const std::shared_ptr<session::Kerne
|
||||||
ParameterPtr iter_loop_new = kernel_graph_ptr->NewParameter(iter_loop);
|
ParameterPtr iter_loop_new = kernel_graph_ptr->NewParameter(iter_loop);
|
||||||
(*switch_loop_input)[kIterLoopParamName] = iter_loop_new;
|
(*switch_loop_input)[kIterLoopParamName] = iter_loop_new;
|
||||||
|
|
||||||
|
ParameterPtr zero = std::make_shared<Parameter>(kernel_graph_ptr);
|
||||||
|
zero->set_name(kZeroParamName);
|
||||||
|
zero->set_abstract(paremeter_abstract_ptr);
|
||||||
|
ParameterPtr zero_new = kernel_graph_ptr->NewParameter(zero);
|
||||||
|
(*switch_loop_input)[kZeroParamName] = zero_new;
|
||||||
|
|
||||||
ParameterPtr one = std::make_shared<Parameter>(kernel_graph_ptr);
|
ParameterPtr one = std::make_shared<Parameter>(kernel_graph_ptr);
|
||||||
one->set_name(kOneParamName);
|
one->set_name(kOneParamName);
|
||||||
one->set_abstract(paremeter_abstract_ptr);
|
one->set_abstract(paremeter_abstract_ptr);
|
||||||
|
@ -378,22 +343,14 @@ kernel::KernelBuildInfo::KernelBuildInfoBuilder KernelAdjust::CreateMngKernelBui
|
||||||
}
|
}
|
||||||
|
|
||||||
CNodePtr KernelAdjust::CreateStreamSwitchOp(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
|
CNodePtr KernelAdjust::CreateStreamSwitchOp(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
|
||||||
const std::map<std::string, mindspore::ParameterPtr> &switch_loop_input,
|
const std::map<std::string, mindspore::ParameterPtr> &switch_loop_input) {
|
||||||
StreamSwitchKind kind) {
|
|
||||||
kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder = CreateMngKernelBuilder(
|
kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder = CreateMngKernelBuilder(
|
||||||
{kOpFormat_DEFAULT, kOpFormat_DEFAULT}, {TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt32});
|
{kOpFormat_DEFAULT, kOpFormat_DEFAULT}, {TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt32});
|
||||||
auto typeNone_abstract = std::make_shared<abstract::AbstractNone>();
|
auto typeNone_abstract = std::make_shared<abstract::AbstractNone>();
|
||||||
auto stream_switch = std::make_shared<Primitive>(kStreamSwitchOpName);
|
auto stream_switch = std::make_shared<Primitive>(kStreamSwitchOpName);
|
||||||
std::vector<AnfNodePtr> inputs;
|
std::vector<AnfNodePtr> inputs;
|
||||||
inputs.push_back(NewValueNode(stream_switch));
|
inputs.push_back(NewValueNode(stream_switch));
|
||||||
if (kind == kFpBpStreamSwitch || kind == kEosStreamSwitch) {
|
inputs.push_back(switch_loop_input.at(kLoopCountParamName));
|
||||||
inputs.push_back(switch_loop_input.at(kCurLoopCountParamName));
|
|
||||||
} else if (kind == kGetNextStreamSwitch || kind == kIndependentStreamSwitch) {
|
|
||||||
inputs.push_back(switch_loop_input.at(kNextLoopCountParamName));
|
|
||||||
} else {
|
|
||||||
MS_LOG(ERROR) << "unknown stream switch kind";
|
|
||||||
}
|
|
||||||
|
|
||||||
inputs.push_back(switch_loop_input.at(kIterLoopParamName));
|
inputs.push_back(switch_loop_input.at(kIterLoopParamName));
|
||||||
MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
|
MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
|
||||||
CNodePtr stream_switch_app = kernel_graph_ptr->NewCNode(inputs);
|
CNodePtr stream_switch_app = kernel_graph_ptr->NewCNode(inputs);
|
||||||
|
@ -476,9 +433,9 @@ CNodePtr KernelAdjust::CreateEndOfSequenceOP(const std::shared_ptr<session::Kern
|
||||||
return end_of_sequence_node;
|
return end_of_sequence_node;
|
||||||
}
|
}
|
||||||
|
|
||||||
CNodePtr KernelAdjust::CreateStreamAssignAddnOP(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
|
CNodePtr KernelAdjust::CreateStreamAssignAddnOP(
|
||||||
const std::map<std::string, mindspore::ParameterPtr> &switch_loop_input,
|
const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
|
||||||
bool cur_loop) {
|
const std::map<std::string, mindspore::ParameterPtr> &switch_loop_input) {
|
||||||
MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
|
MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
|
||||||
kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder = CreateMngKernelBuilder(
|
kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder = CreateMngKernelBuilder(
|
||||||
{kOpFormat_DEFAULT, kOpFormat_DEFAULT}, {TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt32});
|
{kOpFormat_DEFAULT, kOpFormat_DEFAULT}, {TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt32});
|
||||||
|
@ -488,12 +445,7 @@ CNodePtr KernelAdjust::CreateStreamAssignAddnOP(const std::shared_ptr<session::K
|
||||||
auto assign_add = std::make_shared<Primitive>(kAssignAddOpName);
|
auto assign_add = std::make_shared<Primitive>(kAssignAddOpName);
|
||||||
std::vector<AnfNodePtr> inputs;
|
std::vector<AnfNodePtr> inputs;
|
||||||
inputs.push_back(NewValueNode(assign_add));
|
inputs.push_back(NewValueNode(assign_add));
|
||||||
if (cur_loop) {
|
inputs.push_back(switch_loop_input.at(kLoopCountParamName));
|
||||||
inputs.push_back(switch_loop_input.at(kCurLoopCountParamName));
|
|
||||||
} else {
|
|
||||||
inputs.push_back(switch_loop_input.at(kNextLoopCountParamName));
|
|
||||||
}
|
|
||||||
|
|
||||||
inputs.push_back(switch_loop_input.at(kOneParamName));
|
inputs.push_back(switch_loop_input.at(kOneParamName));
|
||||||
CNodePtr assign_add_one = kernel_graph_ptr->NewCNode(inputs);
|
CNodePtr assign_add_one = kernel_graph_ptr->NewCNode(inputs);
|
||||||
MS_EXCEPTION_IF_NULL(assign_add_one);
|
MS_EXCEPTION_IF_NULL(assign_add_one);
|
||||||
|
@ -505,8 +457,8 @@ CNodePtr KernelAdjust::CreateStreamAssignAddnOP(const std::shared_ptr<session::K
|
||||||
AnfAlgo::SetNodeAttr("input_names", input_names_v, assign_add_one);
|
AnfAlgo::SetNodeAttr("input_names", input_names_v, assign_add_one);
|
||||||
AnfAlgo::SetNodeAttr("output_names", output_names_v, assign_add_one);
|
AnfAlgo::SetNodeAttr("output_names", output_names_v, assign_add_one);
|
||||||
selected_kernel_builder.SetKernelType(KernelType::TBE_KERNEL);
|
selected_kernel_builder.SetKernelType(KernelType::TBE_KERNEL);
|
||||||
MS_EXCEPTION_IF_NULL(switch_loop_input.at(kCurLoopCountParamName));
|
MS_EXCEPTION_IF_NULL(switch_loop_input.at(kLoopCountParamName));
|
||||||
assign_add_one->set_abstract(switch_loop_input.at(kCurLoopCountParamName)->abstract());
|
assign_add_one->set_abstract(switch_loop_input.at(kLoopCountParamName)->abstract());
|
||||||
return assign_add_one;
|
return assign_add_one;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -561,23 +513,14 @@ bool KernelAdjust::StepLoadCtrlInputs(const std::shared_ptr<session::KernelGraph
|
||||||
void KernelAdjust::LoadSwitchInputs(std::vector<tensor::TensorPtr> *inputs) {
|
void KernelAdjust::LoadSwitchInputs(std::vector<tensor::TensorPtr> *inputs) {
|
||||||
MS_LOG(INFO) << "---------------- LoadSwitchInputs---";
|
MS_LOG(INFO) << "---------------- LoadSwitchInputs---";
|
||||||
MS_EXCEPTION_IF_NULL(inputs);
|
MS_EXCEPTION_IF_NULL(inputs);
|
||||||
// current loop count
|
|
||||||
std::vector<int> shp = {1};
|
std::vector<int> shp = {1};
|
||||||
tensor::TensorPtr cur_loop_count = std::make_shared<tensor::Tensor>(kInt32->type_id(), shp);
|
tensor::TensorPtr loop_count_tensor = std::make_shared<tensor::Tensor>(kInt32->type_id(), shp);
|
||||||
MS_EXCEPTION_IF_NULL(cur_loop_count);
|
MS_EXCEPTION_IF_NULL(loop_count_tensor);
|
||||||
int32_t *val = nullptr;
|
int32_t *val = nullptr;
|
||||||
val = static_cast<int32_t *>(cur_loop_count->data_c());
|
val = static_cast<int32_t *>(loop_count_tensor->data_c());
|
||||||
MS_EXCEPTION_IF_NULL(val);
|
MS_EXCEPTION_IF_NULL(val);
|
||||||
*val = 0;
|
*val = 0;
|
||||||
inputs->push_back(cur_loop_count);
|
inputs->push_back(loop_count_tensor);
|
||||||
|
|
||||||
// next loop count
|
|
||||||
tensor::TensorPtr next_loop_count = std::make_shared<tensor::Tensor>(kInt32->type_id(), shp);
|
|
||||||
MS_EXCEPTION_IF_NULL(next_loop_count);
|
|
||||||
val = static_cast<int32_t *>(next_loop_count->data_c());
|
|
||||||
MS_EXCEPTION_IF_NULL(val);
|
|
||||||
*val = 0;
|
|
||||||
inputs->push_back(next_loop_count);
|
|
||||||
|
|
||||||
// Epoch in device
|
// Epoch in device
|
||||||
tensor::TensorPtr epoch_tensor = std::make_shared<tensor::Tensor>(kInt32->type_id(), shp);
|
tensor::TensorPtr epoch_tensor = std::make_shared<tensor::Tensor>(kInt32->type_id(), shp);
|
||||||
|
@ -587,7 +530,6 @@ void KernelAdjust::LoadSwitchInputs(std::vector<tensor::TensorPtr> *inputs) {
|
||||||
*val = 0;
|
*val = 0;
|
||||||
inputs->push_back(epoch_tensor);
|
inputs->push_back(epoch_tensor);
|
||||||
|
|
||||||
// total loop count per iter
|
|
||||||
tensor::TensorPtr iter_loop_tensor = std::make_shared<tensor::Tensor>(kInt32->type_id(), shp);
|
tensor::TensorPtr iter_loop_tensor = std::make_shared<tensor::Tensor>(kInt32->type_id(), shp);
|
||||||
MS_EXCEPTION_IF_NULL(iter_loop_tensor);
|
MS_EXCEPTION_IF_NULL(iter_loop_tensor);
|
||||||
val = static_cast<int32_t *>(iter_loop_tensor->data_c());
|
val = static_cast<int32_t *>(iter_loop_tensor->data_c());
|
||||||
|
@ -596,6 +538,13 @@ void KernelAdjust::LoadSwitchInputs(std::vector<tensor::TensorPtr> *inputs) {
|
||||||
MS_LOG(INFO) << "iter_loop_tensor = " << *val;
|
MS_LOG(INFO) << "iter_loop_tensor = " << *val;
|
||||||
inputs->push_back(iter_loop_tensor);
|
inputs->push_back(iter_loop_tensor);
|
||||||
|
|
||||||
|
tensor::TensorPtr zero_tensor = std::make_shared<tensor::Tensor>(kInt32->type_id(), shp);
|
||||||
|
MS_EXCEPTION_IF_NULL(zero_tensor);
|
||||||
|
val = static_cast<int32_t *>(zero_tensor->data_c());
|
||||||
|
MS_EXCEPTION_IF_NULL(val);
|
||||||
|
*val = 0;
|
||||||
|
inputs->push_back(zero_tensor);
|
||||||
|
|
||||||
tensor::TensorPtr one_tensor = std::make_shared<tensor::Tensor>(kInt32->type_id(), shp);
|
tensor::TensorPtr one_tensor = std::make_shared<tensor::Tensor>(kInt32->type_id(), shp);
|
||||||
MS_EXCEPTION_IF_NULL(one_tensor);
|
MS_EXCEPTION_IF_NULL(one_tensor);
|
||||||
val = static_cast<int32_t *>(one_tensor->data_c());
|
val = static_cast<int32_t *>(one_tensor->data_c());
|
||||||
|
|
|
@ -33,19 +33,13 @@
|
||||||
using mindspore::device::ascend::ProfilingTraceInfo;
|
using mindspore::device::ascend::ProfilingTraceInfo;
|
||||||
using mindspore::device::ascend::ProfilingUtils;
|
using mindspore::device::ascend::ProfilingUtils;
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
constexpr auto kCurLoopCountParamName = "cur_loop_count";
|
constexpr auto kLoopCountParamName = "loop_count";
|
||||||
constexpr auto kNextLoopCountParamName = "next_loop_count";
|
|
||||||
constexpr auto kIterLoopParamName = "iter_loop";
|
constexpr auto kIterLoopParamName = "iter_loop";
|
||||||
|
constexpr auto kZeroParamName = "zero";
|
||||||
constexpr auto kOneParamName = "one";
|
constexpr auto kOneParamName = "one";
|
||||||
constexpr auto kEpochParamName = "loop_epoch";
|
constexpr auto kEpochParamName = "loop_epoch";
|
||||||
constexpr auto kStreamNeedActivedFirst = "stream_need_active_first";
|
constexpr auto kStreamNeedActivedFirst = "stream_need_active_first";
|
||||||
constexpr uint32_t kSecondStreamSwitchLabel = 2;
|
constexpr uint32_t kSecondStreamSwitchLabel = 2;
|
||||||
enum StreamSwitchKind {
|
|
||||||
kFpBpStreamSwitch = 0,
|
|
||||||
kGetNextStreamSwitch = 1,
|
|
||||||
kEosStreamSwitch = 2,
|
|
||||||
kIndependentStreamSwitch = 3
|
|
||||||
};
|
|
||||||
|
|
||||||
namespace device {
|
namespace device {
|
||||||
class KernelAdjust {
|
class KernelAdjust {
|
||||||
|
@ -71,22 +65,18 @@ class KernelAdjust {
|
||||||
void CreateSwitchOpParameters(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
|
void CreateSwitchOpParameters(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
|
||||||
std::map<std::string, mindspore::ParameterPtr> *switch_loop_input);
|
std::map<std::string, mindspore::ParameterPtr> *switch_loop_input);
|
||||||
CNodePtr CreateStreamSwitchOp(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
|
CNodePtr CreateStreamSwitchOp(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
|
||||||
const std::map<std::string, mindspore::ParameterPtr> &switch_loop_input,
|
const std::map<std::string, mindspore::ParameterPtr> &switch_loop_input);
|
||||||
StreamSwitchKind kind);
|
|
||||||
|
|
||||||
CNodePtr CreatTupleGetItemNode(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr, const CNodePtr &node,
|
CNodePtr CreatTupleGetItemNode(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr, const CNodePtr &node,
|
||||||
size_t output_idx);
|
size_t output_idx);
|
||||||
CNodePtr CreateEndOfSequenceOP(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
|
CNodePtr CreateEndOfSequenceOP(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
|
||||||
const CNodePtr &getnext_cnode);
|
const CNodePtr &getnext_cnode);
|
||||||
CNodePtr CreateStreamAssignAddnOP(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
|
CNodePtr CreateStreamAssignAddnOP(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
|
||||||
const std::map<std::string, mindspore::ParameterPtr> &switch_loop_input,
|
const std::map<std::string, mindspore::ParameterPtr> &switch_loop_input);
|
||||||
bool cur_loop);
|
|
||||||
kernel::KernelBuildInfo::KernelBuildInfoBuilder CreateMngKernelBuilder(const std::vector<std::string> &formats,
|
kernel::KernelBuildInfo::KernelBuildInfoBuilder CreateMngKernelBuilder(const std::vector<std::string> &formats,
|
||||||
const std::vector<TypeId> &type_ids);
|
const std::vector<TypeId> &type_ids);
|
||||||
void LoadSwitchInputs(std::vector<tensor::TensorPtr> *inputs);
|
void LoadSwitchInputs(std::vector<tensor::TensorPtr> *inputs);
|
||||||
void InsertProfilingKernel(const ProfilingTraceInfo &profiling_trace_info,
|
void InsertProfilingKernel(const ProfilingTraceInfo &profiling_trace_info,
|
||||||
NotNull<session::KernelGraph *> kernel_graph_ptr);
|
NotNull<session::KernelGraph *> kernel_graph_ptr);
|
||||||
bool ExitIndependent(const std::shared_ptr<session::KernelGraph> &graph_ptr);
|
|
||||||
};
|
};
|
||||||
} // namespace device
|
} // namespace device
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -580,14 +580,6 @@ void KernelRuntime::AssignNodeOutputMem(MemType type, const AnfNodePtr &node, in
|
||||||
MS_LOG(INFO) << "GetNext disable mem_reuse";
|
MS_LOG(INFO) << "GetNext disable mem_reuse";
|
||||||
type = kDynamicMem;
|
type = kDynamicMem;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (node->isa<CNode>()) {
|
|
||||||
bool independent = AnfAlgo::IsIndependentNode(node->cast<CNodePtr>());
|
|
||||||
if (independent && type == kReuseDynamicMem) {
|
|
||||||
MS_LOG(INFO) << "Independent disable mem_reuse";
|
|
||||||
type = kDynamicMem;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
auto kernel_mod = AnfAlgo::GetKernelMod(node);
|
auto kernel_mod = AnfAlgo::GetKernelMod(node);
|
||||||
MS_EXCEPTION_IF_NULL(kernel_mod);
|
MS_EXCEPTION_IF_NULL(kernel_mod);
|
||||||
auto output_sizes = kernel_mod->GetOutputSizeList();
|
auto output_sizes = kernel_mod->GetOutputSizeList();
|
||||||
|
|
|
@ -210,7 +210,6 @@ constexpr auto kAttrDataType = "data_type";
|
||||||
constexpr auto kAttrActiveTarget = "active_target";
|
constexpr auto kAttrActiveTarget = "active_target";
|
||||||
constexpr auto kAttrActiveStreamList = "active_stream_list";
|
constexpr auto kAttrActiveStreamList = "active_stream_list";
|
||||||
constexpr auto kAttrTrueBranchStream = "true_branch_stream";
|
constexpr auto kAttrTrueBranchStream = "true_branch_stream";
|
||||||
constexpr auto kAttrStreamSwitchKind = "stream_switch_kind";
|
|
||||||
constexpr auto kAttrEventId = "event_id";
|
constexpr auto kAttrEventId = "event_id";
|
||||||
constexpr auto kAttrDynInput = "dynamic";
|
constexpr auto kAttrDynInput = "dynamic";
|
||||||
constexpr auto kAttrDynInputSizes = "dyn_input_sizes";
|
constexpr auto kAttrDynInputSizes = "dyn_input_sizes";
|
||||||
|
|
Loading…
Reference in New Issue