diff --git a/mindspore/ccsrc/device/kernel_adjust.cc b/mindspore/ccsrc/device/kernel_adjust.cc index f4fe64b4df6..fd0a8eb967b 100644 --- a/mindspore/ccsrc/device/kernel_adjust.cc +++ b/mindspore/ccsrc/device/kernel_adjust.cc @@ -109,6 +109,7 @@ void KernelAdjust::InsertSwitchLoop(const std::shared_ptr return; } MS_EXCEPTION_IF_NULL(kernel_graph_ptr); + bool eos_mode = ConfigManager::GetInstance().iter_num() == INT32_MAX; ReorderGetNext(kernel_graph_ptr); std::map switch_loop_input; CreateSwitchOpParameters(kernel_graph_ptr, &switch_loop_input); @@ -129,12 +130,17 @@ void KernelAdjust::InsertSwitchLoop(const std::shared_ptr } } - auto orders = kernel_graph_ptr->execution_order(); + const std::vector &orders = kernel_graph_ptr->execution_order(); if (orders.empty()) { MS_LOG(EXCEPTION) << "graph execution order is empty"; } std::vector exec_order; + std::vector getnext_active_streams; + std::vector fpbp_active_streams; + CNodePtr getnext_cnode; + uint32_t eos_done_event_id = UINT32_MAX; + // getnext loop process // getnext loop stream switch op CNodePtr getnext_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input); @@ -151,6 +157,7 @@ void KernelAdjust::InsertSwitchLoop(const std::shared_ptr exec_order.push_back(node); AnfAlgo::SetStreamId(getnext_stream_id, exec_order[exec_order.size() - 1].get()); if (AnfAlgo::GetCNodeName(node) == kGetNextOpName) { + getnext_cnode = node; break; } } @@ -158,11 +165,52 @@ void KernelAdjust::InsertSwitchLoop(const std::shared_ptr // update getnext loop stream switch true_branch_stream attr AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue(getnext_stream_id), getnext_switch_app); - // getnext loop send - uint32_t getnext_event_id = resource_manager.ApplyNewEvent(); - CNodePtr send = CreateSendApplyKernel(kernel_graph_ptr, getnext_event_id); - AnfAlgo::SetStreamId(getnext_stream_id, send.get()); - exec_order.push_back(send); + // getnext loop fpbp start send + uint32_t fpbp_start_event_id = resource_manager.ApplyNewEvent(); + CNodePtr fpbp_start_send = CreateSendApplyKernel(kernel_graph_ptr, fpbp_start_event_id); + AnfAlgo::SetStreamId(getnext_stream_id, fpbp_start_send.get()); + exec_order.push_back(fpbp_start_send); + + if (eos_mode) { + // getnext loop eos start send + uint32_t eos_start_event_id = resource_manager.ApplyNewEvent(); + CNodePtr eos_start_send = CreateSendApplyKernel(kernel_graph_ptr, eos_start_event_id); + AnfAlgo::SetStreamId(getnext_stream_id, eos_start_send.get()); + exec_order.push_back(eos_start_send); + + // End Of Sequence loop process + // eos loop stream switch + CNodePtr eos_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input); + MS_EXCEPTION_IF_NULL(eos_switch_app); + uint32_t eos_switch_stream_id = resource_manager.ApplyNewStream(); + AnfAlgo::SetStreamId(eos_switch_stream_id, eos_switch_app.get()); + AnfAlgo::SetNodeAttr(kStreamNeedActivedFirst, MakeValue(true), eos_switch_app); + exec_order.push_back(eos_switch_app); + + // eos loop eos start recv + CNodePtr eos_start_recv = CreateRecvApplyKernel(kernel_graph_ptr, eos_start_event_id); + uint32_t eos_stream_id = resource_manager.ApplyNewStream(); + AnfAlgo::SetStreamId(eos_stream_id, eos_start_recv.get()); + exec_order.push_back(eos_start_recv); + + // update eos loop stream switch true_branch_stream attr + AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue(eos_stream_id), eos_switch_app); + + // EndOfSequence op + CNodePtr end_of_sequence_op = CreateEndOfSequenceOP(kernel_graph_ptr, getnext_cnode); + MS_EXCEPTION_IF_NULL(end_of_sequence_op); + AnfAlgo::SetStreamId(eos_stream_id, end_of_sequence_op.get()); + exec_order.push_back(end_of_sequence_op); + + // eos loop eos done send + eos_done_event_id = resource_manager.ApplyNewEvent(); + CNodePtr eos_done_send = CreateSendApplyKernel(kernel_graph_ptr, eos_done_event_id); + AnfAlgo::SetStreamId(eos_stream_id, eos_done_send.get()); + exec_order.push_back(eos_done_send); + + // eos loop stream active + fpbp_active_streams.push_back(eos_switch_stream_id); + } // fpbp loop process // fpbp loop stream switch @@ -173,11 +221,11 @@ void KernelAdjust::InsertSwitchLoop(const std::shared_ptr AnfAlgo::SetNodeAttr(kStreamNeedActivedFirst, MakeValue(true), fpbp_switch_app); exec_order.push_back(fpbp_switch_app); - // fpbp loop recv - CNodePtr recv = CreateRecvApplyKernel(kernel_graph_ptr, getnext_event_id); + // fpbp loop fpbp start recv + CNodePtr fpbp_start_recv = CreateRecvApplyKernel(kernel_graph_ptr, fpbp_start_event_id); uint32_t fpbp_stream_id = resource_manager.ApplyNewStream(); - AnfAlgo::SetStreamId(fpbp_stream_id, recv.get()); - exec_order.push_back(recv); + AnfAlgo::SetStreamId(fpbp_stream_id, fpbp_start_recv.get()); + exec_order.push_back(fpbp_start_recv); // update fpbp loop stream switch true_branch_stream attr AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue(fpbp_stream_id), fpbp_switch_app); @@ -190,40 +238,41 @@ void KernelAdjust::InsertSwitchLoop(const std::shared_ptr // fpbp memcpy std::vector memcpy_list; - std::vector before_list; - std::vector after_list; - bool first_memcpy_found = false; + std::vector other_list; CNodePtr cur_cnode = nullptr; for (size_t idx = i + 1; idx < orders.size(); idx++) { cur_cnode = orders[idx]; if (AnfAlgo::HasNodeAttr(kAttrLabelForInsertStreamActive, cur_cnode)) { memcpy_list.emplace_back(cur_cnode); - first_memcpy_found = true; - } else if (first_memcpy_found) { - after_list.emplace_back(cur_cnode); } else { - before_list.emplace_back(cur_cnode); + other_list.emplace_back(cur_cnode); } } - (void)std::copy(before_list.begin(), before_list.end(), std::back_inserter(exec_order)); + (void)std::copy(memcpy_list.begin(), memcpy_list.end(), std::back_inserter(exec_order)); + // fpbp loop eos done recv + if (eos_mode) { + CNodePtr eos_done_recv = CreateRecvApplyKernel(kernel_graph_ptr, eos_done_event_id); + AnfAlgo::SetStreamId(fpbp_stream_id, eos_done_recv.get()); + exec_order.push_back(eos_done_recv); + } + // stream active to activate getnext loop CNodePtr getnext_active_app = CreateStreamActiveOp(kernel_graph_ptr); MS_EXCEPTION_IF_NULL(getnext_active_app); - std::vector getnext_active_streams = {getnext_switch_stream_id}; + getnext_active_streams.push_back(getnext_switch_stream_id); AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue>(getnext_active_streams), getnext_active_app); exec_order.push_back(getnext_active_app); // fpbp loop other ops - (void)std::copy(after_list.begin(), after_list.end(), std::back_inserter(exec_order)); + (void)std::copy(other_list.begin(), other_list.end(), std::back_inserter(exec_order)); - // stream active to activate fpbp loop + // stream active to activate fpbp loop and eos loop CNodePtr fpbp_active_app = CreateStreamActiveOp(kernel_graph_ptr); MS_EXCEPTION_IF_NULL(fpbp_active_app); - // specific deal for common ctrl stream policy - std::vector fpbp_active_streams = {fpbp_switch_stream_id}; + fpbp_active_streams.push_back(fpbp_switch_stream_id); AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue>(fpbp_active_streams), fpbp_active_app); exec_order.push_back(fpbp_active_app); @@ -323,6 +372,55 @@ CNodePtr KernelAdjust::CreateStreamActiveOp(const std::shared_ptr &kernel_graph_ptr, + const CNodePtr &node, size_t output_idx) { + auto idx = NewValueNode(SizeToInt(output_idx)); + MS_EXCEPTION_IF_NULL(idx); + auto imm = std::make_shared(SizeToInt(output_idx)); + auto abstract_scalar = std::make_shared(imm); + idx->set_abstract(abstract_scalar); + CNodePtr tuple_getitem = kernel_graph_ptr->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx}); + MS_EXCEPTION_IF_NULL(tuple_getitem); + tuple_getitem->set_scope(node->scope()); + std::vector origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx); + TypeId origin_type = AnfAlgo::GetOutputInferDataType(node, output_idx); + AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, tuple_getitem.get()); + return tuple_getitem; +} + +CNodePtr KernelAdjust::CreateEndOfSequenceOP(const std::shared_ptr &kernel_graph_ptr, + const CNodePtr &getnext_cnode) { + MS_EXCEPTION_IF_NULL(kernel_graph_ptr); + kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder; + selected_kernel_builder.SetInputsFormat({kOpFormat_DEFAULT}); + selected_kernel_builder.SetInputsDeviceType({kNumberTypeUInt8}); + + selected_kernel_builder.SetFusionType(kernel::FusionType::OPAQUE); + selected_kernel_builder.SetProcessor(kernel::Processor::AICPU); + selected_kernel_builder.SetKernelType(KernelType::AICPU_KERNEL); + + selected_kernel_builder.SetOutputsFormat({kOpFormat_DEFAULT}); + selected_kernel_builder.SetOutputsDeviceType({kNumberTypeUInt8}); + // EndOfSequence + auto end_of_sequence = std::make_shared(kEndOfSequence); + std::vector inputs; + inputs.push_back(NewValueNode(end_of_sequence)); + // GetNext output 0 is EndOfSequence's input + auto tuple_get_item = CreatTupleGetItemNode(kernel_graph_ptr, getnext_cnode, 0); + inputs.push_back(tuple_get_item); + CNodePtr end_of_sequence_node = kernel_graph_ptr->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(end_of_sequence_node); + AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), end_of_sequence_node.get()); + std::vector input_names = {"x"}; + ValuePtr input_names_v = MakeValue(input_names); + AnfAlgo::SetNodeAttr("input_names", input_names_v, end_of_sequence_node); + std::vector output_names = {"y"}; + ValuePtr output_names_v = MakeValue(output_names); + AnfAlgo::SetNodeAttr("output_names", output_names_v, end_of_sequence_node); + end_of_sequence_node->set_abstract(tuple_get_item->abstract()); + return end_of_sequence_node; +} + CNodePtr KernelAdjust::CreateStreamAssignAddnOP( const std::shared_ptr &kernel_graph_ptr, const std::map &switch_loop_input) { diff --git a/mindspore/ccsrc/device/kernel_adjust.h b/mindspore/ccsrc/device/kernel_adjust.h index 5dc559408a1..bf3ba2acb25 100644 --- a/mindspore/ccsrc/device/kernel_adjust.h +++ b/mindspore/ccsrc/device/kernel_adjust.h @@ -65,6 +65,10 @@ class KernelAdjust { std::map *switch_loop_input); CNodePtr CreateStreamSwitchOp(const std::shared_ptr &kernel_graph_ptr, const std::map &switch_loop_input); + CNodePtr CreatTupleGetItemNode(const std::shared_ptr &kernel_graph_ptr, const CNodePtr &node, + size_t output_idx); + CNodePtr CreateEndOfSequenceOP(const std::shared_ptr &kernel_graph_ptr, + const CNodePtr &getnext_cnode); CNodePtr CreateStreamAssignAddnOP(const std::shared_ptr &kernel_graph_ptr, const std::map &switch_loop_input); kernel::KernelBuildInfo::KernelBuildInfoBuilder CreateMngKernelBuilder(const std::vector &formats, diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index 8d0f729e50c..d10d5830fa4 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -44,6 +44,7 @@ constexpr auto kBNGrad3OpName = "BNGrad3"; constexpr auto kClearZeroOpName = "ClearZero"; constexpr auto kAtomicAddrCleanOpName = "AtomicAddrClean"; constexpr auto kGetNextOpName = "GetNext"; +constexpr auto kEndOfSequence = "EndOfSequence"; constexpr auto kAllReduceOpName = "AllReduce"; constexpr auto kAllGatherOpName = "AllGather"; constexpr auto kHostAllGatherOpName = "HostAllGather"; diff --git a/mindspore/ops/_op_impl/aicpu/__init__.py b/mindspore/ops/_op_impl/aicpu/__init__.py index 9349e10cfff..7b86e47e36b 100644 --- a/mindspore/ops/_op_impl/aicpu/__init__.py +++ b/mindspore/ops/_op_impl/aicpu/__init__.py @@ -29,3 +29,4 @@ from .normal import _normal_aicpu from .ctcloss import _ctcloss_aicpu from .reverse_sequence import _reverse_sequence_aicpu from .crop_and_resize import _crop_and_resize_aicpu +from .end_of_sequence import _end_of_sequence_aicpu diff --git a/mindspore/ops/_op_impl/aicpu/end_of_sequence.py b/mindspore/ops/_op_impl/aicpu/end_of_sequence.py new file mode 100644 index 00000000000..da70cc12f57 --- /dev/null +++ b/mindspore/ops/_op_impl/aicpu/end_of_sequence.py @@ -0,0 +1,30 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""EndOfSequence op""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType + +end_of_sequence_op_info = AiCPURegOp("EndOfSequence") \ + .fusion_type("OPAQUE") \ + .input(0, "x", "required") \ + .output(0, "y", "required") \ + .dtype_format(DataType.U8_Default, DataType.U8_Default) \ + .get_op_info() + + +@op_info_register(end_of_sequence_op_info) +def _end_of_sequence_aicpu(): + """EndOfSequence AiCPU register""" + return