forked from mindspore-Ecosystem/mindspore
!2677 End Of Sequence in vm
Merge pull request !2677 from laiyongqiang/eos_commit
This commit is contained in:
commit
71fd4321c6
|
@ -109,6 +109,7 @@ void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph>
|
|||
return;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
|
||||
bool eos_mode = ConfigManager::GetInstance().iter_num() == INT32_MAX;
|
||||
ReorderGetNext(kernel_graph_ptr);
|
||||
std::map<std::string, mindspore::ParameterPtr> switch_loop_input;
|
||||
CreateSwitchOpParameters(kernel_graph_ptr, &switch_loop_input);
|
||||
|
@ -129,12 +130,17 @@ void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph>
|
|||
}
|
||||
}
|
||||
|
||||
auto orders = kernel_graph_ptr->execution_order();
|
||||
const std::vector<CNodePtr> &orders = kernel_graph_ptr->execution_order();
|
||||
if (orders.empty()) {
|
||||
MS_LOG(EXCEPTION) << "graph execution order is empty";
|
||||
}
|
||||
|
||||
std::vector<CNodePtr> exec_order;
|
||||
std::vector<uint32_t> getnext_active_streams;
|
||||
std::vector<uint32_t> fpbp_active_streams;
|
||||
CNodePtr getnext_cnode;
|
||||
uint32_t eos_done_event_id = UINT32_MAX;
|
||||
|
||||
// getnext loop process
|
||||
// getnext loop stream switch op
|
||||
CNodePtr getnext_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input);
|
||||
|
@ -151,6 +157,7 @@ void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph>
|
|||
exec_order.push_back(node);
|
||||
AnfAlgo::SetStreamId(getnext_stream_id, exec_order[exec_order.size() - 1].get());
|
||||
if (AnfAlgo::GetCNodeName(node) == kGetNextOpName) {
|
||||
getnext_cnode = node;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
@ -158,11 +165,52 @@ void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph>
|
|||
// update getnext loop stream switch true_branch_stream attr
|
||||
AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue<uint32_t>(getnext_stream_id), getnext_switch_app);
|
||||
|
||||
// getnext loop send
|
||||
uint32_t getnext_event_id = resource_manager.ApplyNewEvent();
|
||||
CNodePtr send = CreateSendApplyKernel(kernel_graph_ptr, getnext_event_id);
|
||||
AnfAlgo::SetStreamId(getnext_stream_id, send.get());
|
||||
exec_order.push_back(send);
|
||||
// getnext loop fpbp start send
|
||||
uint32_t fpbp_start_event_id = resource_manager.ApplyNewEvent();
|
||||
CNodePtr fpbp_start_send = CreateSendApplyKernel(kernel_graph_ptr, fpbp_start_event_id);
|
||||
AnfAlgo::SetStreamId(getnext_stream_id, fpbp_start_send.get());
|
||||
exec_order.push_back(fpbp_start_send);
|
||||
|
||||
if (eos_mode) {
|
||||
// getnext loop eos start send
|
||||
uint32_t eos_start_event_id = resource_manager.ApplyNewEvent();
|
||||
CNodePtr eos_start_send = CreateSendApplyKernel(kernel_graph_ptr, eos_start_event_id);
|
||||
AnfAlgo::SetStreamId(getnext_stream_id, eos_start_send.get());
|
||||
exec_order.push_back(eos_start_send);
|
||||
|
||||
// End Of Sequence loop process
|
||||
// eos loop stream switch
|
||||
CNodePtr eos_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input);
|
||||
MS_EXCEPTION_IF_NULL(eos_switch_app);
|
||||
uint32_t eos_switch_stream_id = resource_manager.ApplyNewStream();
|
||||
AnfAlgo::SetStreamId(eos_switch_stream_id, eos_switch_app.get());
|
||||
AnfAlgo::SetNodeAttr(kStreamNeedActivedFirst, MakeValue<bool>(true), eos_switch_app);
|
||||
exec_order.push_back(eos_switch_app);
|
||||
|
||||
// eos loop eos start recv
|
||||
CNodePtr eos_start_recv = CreateRecvApplyKernel(kernel_graph_ptr, eos_start_event_id);
|
||||
uint32_t eos_stream_id = resource_manager.ApplyNewStream();
|
||||
AnfAlgo::SetStreamId(eos_stream_id, eos_start_recv.get());
|
||||
exec_order.push_back(eos_start_recv);
|
||||
|
||||
// update eos loop stream switch true_branch_stream attr
|
||||
AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue<uint32_t>(eos_stream_id), eos_switch_app);
|
||||
|
||||
// EndOfSequence op
|
||||
CNodePtr end_of_sequence_op = CreateEndOfSequenceOP(kernel_graph_ptr, getnext_cnode);
|
||||
MS_EXCEPTION_IF_NULL(end_of_sequence_op);
|
||||
AnfAlgo::SetStreamId(eos_stream_id, end_of_sequence_op.get());
|
||||
exec_order.push_back(end_of_sequence_op);
|
||||
|
||||
// eos loop eos done send
|
||||
eos_done_event_id = resource_manager.ApplyNewEvent();
|
||||
CNodePtr eos_done_send = CreateSendApplyKernel(kernel_graph_ptr, eos_done_event_id);
|
||||
AnfAlgo::SetStreamId(eos_stream_id, eos_done_send.get());
|
||||
exec_order.push_back(eos_done_send);
|
||||
|
||||
// eos loop stream active
|
||||
fpbp_active_streams.push_back(eos_switch_stream_id);
|
||||
}
|
||||
|
||||
// fpbp loop process
|
||||
// fpbp loop stream switch
|
||||
|
@ -173,11 +221,11 @@ void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph>
|
|||
AnfAlgo::SetNodeAttr(kStreamNeedActivedFirst, MakeValue<bool>(true), fpbp_switch_app);
|
||||
exec_order.push_back(fpbp_switch_app);
|
||||
|
||||
// fpbp loop recv
|
||||
CNodePtr recv = CreateRecvApplyKernel(kernel_graph_ptr, getnext_event_id);
|
||||
// fpbp loop fpbp start recv
|
||||
CNodePtr fpbp_start_recv = CreateRecvApplyKernel(kernel_graph_ptr, fpbp_start_event_id);
|
||||
uint32_t fpbp_stream_id = resource_manager.ApplyNewStream();
|
||||
AnfAlgo::SetStreamId(fpbp_stream_id, recv.get());
|
||||
exec_order.push_back(recv);
|
||||
AnfAlgo::SetStreamId(fpbp_stream_id, fpbp_start_recv.get());
|
||||
exec_order.push_back(fpbp_start_recv);
|
||||
|
||||
// update fpbp loop stream switch true_branch_stream attr
|
||||
AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue<uint32_t>(fpbp_stream_id), fpbp_switch_app);
|
||||
|
@ -190,40 +238,41 @@ void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph>
|
|||
|
||||
// fpbp memcpy
|
||||
std::vector<CNodePtr> memcpy_list;
|
||||
std::vector<CNodePtr> before_list;
|
||||
std::vector<CNodePtr> after_list;
|
||||
bool first_memcpy_found = false;
|
||||
std::vector<CNodePtr> other_list;
|
||||
CNodePtr cur_cnode = nullptr;
|
||||
for (size_t idx = i + 1; idx < orders.size(); idx++) {
|
||||
cur_cnode = orders[idx];
|
||||
if (AnfAlgo::HasNodeAttr(kAttrLabelForInsertStreamActive, cur_cnode)) {
|
||||
memcpy_list.emplace_back(cur_cnode);
|
||||
first_memcpy_found = true;
|
||||
} else if (first_memcpy_found) {
|
||||
after_list.emplace_back(cur_cnode);
|
||||
} else {
|
||||
before_list.emplace_back(cur_cnode);
|
||||
other_list.emplace_back(cur_cnode);
|
||||
}
|
||||
}
|
||||
(void)std::copy(before_list.begin(), before_list.end(), std::back_inserter(exec_order));
|
||||
|
||||
(void)std::copy(memcpy_list.begin(), memcpy_list.end(), std::back_inserter(exec_order));
|
||||
|
||||
// fpbp loop eos done recv
|
||||
if (eos_mode) {
|
||||
CNodePtr eos_done_recv = CreateRecvApplyKernel(kernel_graph_ptr, eos_done_event_id);
|
||||
AnfAlgo::SetStreamId(fpbp_stream_id, eos_done_recv.get());
|
||||
exec_order.push_back(eos_done_recv);
|
||||
}
|
||||
|
||||
// stream active to activate getnext loop
|
||||
CNodePtr getnext_active_app = CreateStreamActiveOp(kernel_graph_ptr);
|
||||
MS_EXCEPTION_IF_NULL(getnext_active_app);
|
||||
std::vector<uint32_t> getnext_active_streams = {getnext_switch_stream_id};
|
||||
getnext_active_streams.push_back(getnext_switch_stream_id);
|
||||
AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue<std::vector<uint32_t>>(getnext_active_streams),
|
||||
getnext_active_app);
|
||||
exec_order.push_back(getnext_active_app);
|
||||
|
||||
// fpbp loop other ops
|
||||
(void)std::copy(after_list.begin(), after_list.end(), std::back_inserter(exec_order));
|
||||
(void)std::copy(other_list.begin(), other_list.end(), std::back_inserter(exec_order));
|
||||
|
||||
// stream active to activate fpbp loop
|
||||
// stream active to activate fpbp loop and eos loop
|
||||
CNodePtr fpbp_active_app = CreateStreamActiveOp(kernel_graph_ptr);
|
||||
MS_EXCEPTION_IF_NULL(fpbp_active_app);
|
||||
// specific deal for common ctrl stream policy
|
||||
std::vector<uint32_t> fpbp_active_streams = {fpbp_switch_stream_id};
|
||||
fpbp_active_streams.push_back(fpbp_switch_stream_id);
|
||||
AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue<std::vector<uint32_t>>(fpbp_active_streams), fpbp_active_app);
|
||||
exec_order.push_back(fpbp_active_app);
|
||||
|
||||
|
@ -323,6 +372,55 @@ CNodePtr KernelAdjust::CreateStreamActiveOp(const std::shared_ptr<session::Kerne
|
|||
return stream_active_others_app;
|
||||
}
|
||||
|
||||
// Builds a TupleGetItem CNode in `kernel_graph_ptr` that selects output `output_idx` of `node`.
//
// The index is passed as a value node carrying an Int32 scalar abstract. The new node takes the scope of
// `node`, and its inferred data type and shape are copied from output `output_idx` of `node`.
// `kernel_graph_ptr`, `node` and every node created here are guarded with MS_EXCEPTION_IF_NULL.
CNodePtr KernelAdjust::CreatTupleGetItemNode(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
                                             const CNodePtr &node, size_t output_idx) {
  // Both pointers are dereferenced below; `node` in particular is null when the caller never found the
  // op it wanted to read from, so report that explicitly instead of crashing.
  MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
  MS_EXCEPTION_IF_NULL(node);

  // Index operand of TupleGetItem, with a scalar abstract holding the same index value.
  auto idx = NewValueNode(SizeToInt(output_idx));
  MS_EXCEPTION_IF_NULL(idx);
  auto imm = std::make_shared<Int32Imm>(SizeToInt(output_idx));
  auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(imm);
  idx->set_abstract(abstract_scalar);

  CNodePtr tuple_getitem = kernel_graph_ptr->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx});
  MS_EXCEPTION_IF_NULL(tuple_getitem);
  tuple_getitem->set_scope(node->scope());

  // The selected element keeps the inferred type and shape of the source output.
  std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx);
  TypeId origin_type = AnfAlgo::GetOutputInferDataType(node, output_idx);
  AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, tuple_getitem.get());
  return tuple_getitem;
}
|
||||
|
||||
// Builds the EndOfSequence kernel node in `kernel_graph_ptr`.
//
// The node is an AICPU kernel (OPAQUE fusion type) with one UInt8 input and one UInt8 output, both in
// the default format. Its single input "x" is output 0 of `getnext_cnode`, obtained through a
// TupleGetItem node; its output is named "y" and reuses that TupleGetItem node's abstract.
// `kernel_graph_ptr`, `getnext_cnode` and every node created here are guarded with MS_EXCEPTION_IF_NULL.
CNodePtr KernelAdjust::CreateEndOfSequenceOP(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
                                             const CNodePtr &getnext_cnode) {
  MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
  // The caller only has a GetNext node if one was found in the execution order; without it there is
  // nothing to feed EndOfSequence, so fail here with a clear diagnostic.
  MS_EXCEPTION_IF_NULL(getnext_cnode);

  kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder;
  selected_kernel_builder.SetInputsFormat({kOpFormat_DEFAULT});
  selected_kernel_builder.SetInputsDeviceType({kNumberTypeUInt8});

  selected_kernel_builder.SetFusionType(kernel::FusionType::OPAQUE);
  selected_kernel_builder.SetProcessor(kernel::Processor::AICPU);
  selected_kernel_builder.SetKernelType(KernelType::AICPU_KERNEL);

  selected_kernel_builder.SetOutputsFormat({kOpFormat_DEFAULT});
  selected_kernel_builder.SetOutputsDeviceType({kNumberTypeUInt8});

  // EndOfSequence
  auto end_of_sequence = std::make_shared<Primitive>(kEndOfSequence);
  std::vector<AnfNodePtr> inputs;
  inputs.push_back(NewValueNode(end_of_sequence));
  // GetNext output 0 is EndOfSequence's input
  auto tuple_get_item = CreatTupleGetItemNode(kernel_graph_ptr, getnext_cnode, 0);
  // Dereferenced below for its abstract, so make sure it was actually created.
  MS_EXCEPTION_IF_NULL(tuple_get_item);
  inputs.push_back(tuple_get_item);

  CNodePtr end_of_sequence_node = kernel_graph_ptr->NewCNode(inputs);
  MS_EXCEPTION_IF_NULL(end_of_sequence_node);
  AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), end_of_sequence_node.get());

  std::vector<std::string> input_names = {"x"};
  ValuePtr input_names_v = MakeValue(input_names);
  AnfAlgo::SetNodeAttr("input_names", input_names_v, end_of_sequence_node);

  std::vector<std::string> output_names = {"y"};
  ValuePtr output_names_v = MakeValue(output_names);
  AnfAlgo::SetNodeAttr("output_names", output_names_v, end_of_sequence_node);

  end_of_sequence_node->set_abstract(tuple_get_item->abstract());
  return end_of_sequence_node;
}
|
||||
|
||||
CNodePtr KernelAdjust::CreateStreamAssignAddnOP(
|
||||
const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
|
||||
const std::map<std::string, mindspore::ParameterPtr> &switch_loop_input) {
|
||||
|
|
|
@ -65,6 +65,10 @@ class KernelAdjust {
|
|||
std::map<std::string, mindspore::ParameterPtr> *switch_loop_input);
|
||||
CNodePtr CreateStreamSwitchOp(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
|
||||
const std::map<std::string, mindspore::ParameterPtr> &switch_loop_input);
|
||||
// Creates a TupleGetItem CNode in kernel_graph_ptr that selects output output_idx of node; the new
// node takes node's scope and the inferred type/shape of that output.
CNodePtr CreatTupleGetItemNode(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr, const CNodePtr &node,
|
||||
size_t output_idx);
|
||||
// Creates the EndOfSequence AICPU kernel node in kernel_graph_ptr; its input is output 0 of
// getnext_cnode (taken through a TupleGetItem node).
CNodePtr CreateEndOfSequenceOP(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
|
||||
const CNodePtr &getnext_cnode);
|
||||
CNodePtr CreateStreamAssignAddnOP(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
|
||||
const std::map<std::string, mindspore::ParameterPtr> &switch_loop_input);
|
||||
kernel::KernelBuildInfo::KernelBuildInfoBuilder CreateMngKernelBuilder(const std::vector<std::string> &formats,
|
||||
|
|
|
@ -44,6 +44,7 @@ constexpr auto kBNGrad3OpName = "BNGrad3";
|
|||
constexpr auto kClearZeroOpName = "ClearZero";
|
||||
constexpr auto kAtomicAddrCleanOpName = "AtomicAddrClean";
|
||||
constexpr auto kGetNextOpName = "GetNext";
|
||||
constexpr auto kEndOfSequence = "EndOfSequence";
|
||||
constexpr auto kAllReduceOpName = "AllReduce";
|
||||
constexpr auto kAllGatherOpName = "AllGather";
|
||||
constexpr auto kHostAllGatherOpName = "HostAllGather";
|
||||
|
|
|
@ -29,3 +29,4 @@ from .normal import _normal_aicpu
|
|||
from .ctcloss import _ctcloss_aicpu
|
||||
from .reverse_sequence import _reverse_sequence_aicpu
|
||||
from .crop_and_resize import _crop_and_resize_aicpu
|
||||
from .end_of_sequence import _end_of_sequence_aicpu
|
||||
|
|
|
@ -0,0 +1,30 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""EndOfSequence op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
||||
|
||||
# Registration info for the EndOfSequence AiCPU operator: one required input "x", one required
# output "y", both registered with the U8_Default dtype/format, fusion type OPAQUE.
end_of_sequence_op_info = (
    AiCPURegOp("EndOfSequence")
    .fusion_type("OPAQUE")
    .input(0, "x", "required")
    .output(0, "y", "required")
    .dtype_format(DataType.U8_Default, DataType.U8_Default)
    .get_op_info()
)


@op_info_register(end_of_sequence_op_info)
def _end_of_sequence_aicpu():
    """Register the EndOfSequence operator info for AiCPU; the function body itself does nothing."""
    return
|
Loading…
Reference in New Issue