From 3d38bcf791d1733fc19c7052f1b5425ad1347ce7 Mon Sep 17 00:00:00 2001 From: zhoufeng Date: Tue, 13 Apr 2021 19:45:19 +0800 Subject: [PATCH] use internal ge runtime Signed-off-by: zhoufeng --- cmake/package.cmake | 1 - mindspore/ccsrc/CMakeLists.txt | 2 +- .../kernel_compiler/aicpu/aicpu_kernel_mod.cc | 8 +- .../akg/ascend/akg_ascend_kernel_mod.cc | 4 +- .../kernel_compiler/ascend_kernel_mod.h | 4 +- .../kernel_compiler/hccl/hccl_kernel.cc | 4 +- .../backend/kernel_compiler/rts/assign.cc | 2 +- .../backend/kernel_compiler/rts/label_goto.cc | 2 +- .../backend/kernel_compiler/rts/label_set.cc | 2 +- .../kernel_compiler/rts/label_switch.cc | 2 +- .../kernel_compiler/rts/memcpy_async.cc | 2 +- .../rts/profiling_kernel_mod.cc | 2 +- .../ccsrc/backend/kernel_compiler/rts/recv.cc | 2 +- .../ccsrc/backend/kernel_compiler/rts/send.cc | 2 +- .../kernel_compiler/rts/stream_active.cc | 2 +- .../kernel_compiler/rts/stream_switch.cc | 2 +- .../kernel_compiler/tbe/tbe_kernel_mod.cc | 4 +- .../tbe/tbe_kernel_parallel_build.cc | 6 +- .../ps/ps_cache/ascend/ascend_ps_cache.cc | 1 + .../ps/ps_cache/ascend/ascend_ps_cache.h | 1 + mindspore/ccsrc/runtime/device/CMakeLists.txt | 6 + .../device/ascend/ascend_kernel_runtime.cc | 50 +- .../device/ascend/ascend_kernel_runtime.h | 4 +- .../device/ascend/ascend_memory_pool.cc | 1 + .../device/ascend/ge_runtime/davinci_model.h | 92 ++++ .../device/ascend/ge_runtime/model_context.h | 59 +++ .../device/ascend/ge_runtime/model_runner.cc | 104 ++++ .../device/ascend/ge_runtime/model_runner.h | 60 +++ .../device/ascend/ge_runtime/runtime_model.cc | 292 +++++++++++ .../device/ascend/ge_runtime/runtime_model.h | 71 +++ .../ascend/ge_runtime/task/aicpu_task.cc | 168 +++++++ .../ascend/ge_runtime/task/aicpu_task.h | 51 ++ .../ge_runtime/task/event_record_task.cc | 54 ++ .../ge_runtime/task/event_record_task.h | 38 ++ .../ascend/ge_runtime/task/event_wait_task.cc | 59 +++ .../ascend/ge_runtime/task/event_wait_task.h | 38 ++ .../ascend/ge_runtime/task/hccl_task.cc | 221 ++++++++ .../device/ascend/ge_runtime/task/hccl_task.h | 68 +++ .../ascend/ge_runtime/task/label_goto_task.cc | 83 +++ .../ascend/ge_runtime/task/label_goto_task.h | 46 ++ .../ascend/ge_runtime/task/label_manager.cc | 116 +++++ .../ascend/ge_runtime/task/label_manager.h | 51 ++ .../ascend/ge_runtime/task/label_set_task.cc | 56 +++ .../ascend/ge_runtime/task/label_set_task.h | 38 ++ .../ge_runtime/task/label_switch_task.cc | 77 +++ .../ge_runtime/task/label_switch_task.h | 43 ++ .../ge_runtime/task/memcpy_async_task.cc | 51 ++ .../ge_runtime/task/memcpy_async_task.h | 37 ++ .../ascend/ge_runtime/task/profiler_task.cc | 47 ++ .../ascend/ge_runtime/task/profiler_task.h | 37 ++ .../ge_runtime/task/stream_active_task.cc | 56 +++ .../ge_runtime/task/stream_active_task.h | 38 ++ .../ge_runtime/task/stream_switch_task.cc | 70 +++ .../ge_runtime/task/stream_switch_task.h | 40 ++ .../device/ascend/ge_runtime/task/task.h | 53 ++ .../ascend/ge_runtime/task/task_factory.h | 84 ++++ .../device/ascend/ge_runtime/task/tbe_task.cc | 97 ++++ .../device/ascend/ge_runtime/task/tbe_task.h | 44 ++ .../device/ascend/ge_runtime/task_info.h | 364 ++++++++++++++ .../device/ascend/tasksink/task_generator.h | 2 +- mindspore/core/utils/log_adapter.h | 58 +-- tests/ut/cpp/CMakeLists.txt | 2 + tests/ut/cpp/device/ge_runtime_test.cc | 473 ++++++++++++++++++ tests/ut/cpp/stub/ge/ge_task_launch_stub.cc | 43 -- tests/ut/cpp/stub/runtime/runtime_stub.cc | 37 +- 65 files changed, 3495 insertions(+), 139 deletions(-) create mode 100644 
mindspore/ccsrc/runtime/device/ascend/ge_runtime/davinci_model.h create mode 100644 mindspore/ccsrc/runtime/device/ascend/ge_runtime/model_context.h create mode 100644 mindspore/ccsrc/runtime/device/ascend/ge_runtime/model_runner.cc create mode 100644 mindspore/ccsrc/runtime/device/ascend/ge_runtime/model_runner.h create mode 100644 mindspore/ccsrc/runtime/device/ascend/ge_runtime/runtime_model.cc create mode 100644 mindspore/ccsrc/runtime/device/ascend/ge_runtime/runtime_model.h create mode 100644 mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/aicpu_task.cc create mode 100644 mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/aicpu_task.h create mode 100644 mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/event_record_task.cc create mode 100644 mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/event_record_task.h create mode 100644 mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/event_wait_task.cc create mode 100644 mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/event_wait_task.h create mode 100644 mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/hccl_task.cc create mode 100644 mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/hccl_task.h create mode 100644 mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/label_goto_task.cc create mode 100644 mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/label_goto_task.h create mode 100644 mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/label_manager.cc create mode 100644 mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/label_manager.h create mode 100644 mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/label_set_task.cc create mode 100644 mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/label_set_task.h create mode 100644 mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/label_switch_task.cc create mode 100644 mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/label_switch_task.h create mode 100644 mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/memcpy_async_task.cc create mode 100644 mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/memcpy_async_task.h create mode 100644 mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/profiler_task.cc create mode 100644 mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/profiler_task.h create mode 100644 mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/stream_active_task.cc create mode 100644 mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/stream_active_task.h create mode 100644 mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/stream_switch_task.cc create mode 100644 mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/stream_switch_task.h create mode 100644 mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/task.h create mode 100644 mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/task_factory.h create mode 100644 mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/tbe_task.cc create mode 100644 mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/tbe_task.h create mode 100644 mindspore/ccsrc/runtime/device/ascend/ge_runtime/task_info.h create mode 100644 tests/ut/cpp/device/ge_runtime_test.cc diff --git a/cmake/package.cmake b/cmake/package.cmake index 611fc537b6a..bad26872505 100644 --- a/cmake/package.cmake +++ b/cmake/package.cmake @@ -252,7 +252,6 @@ if(NOT ENABLE_GE) FILES ${CMAKE_BINARY_DIR}/graphengine/metadef/graph/libgraph.so ${CMAKE_BINARY_DIR}/graphengine/ge/common/libge_common.so - ${CMAKE_BINARY_DIR}/graphengine/ge/ge_runtime/libge_runtime.so DESTINATION ${INSTALL_LIB_DIR} 
COMPONENT mindspore ) diff --git a/mindspore/ccsrc/CMakeLists.txt b/mindspore/ccsrc/CMakeLists.txt index 93583b4578a..9a3bdac4684 100644 --- a/mindspore/ccsrc/CMakeLists.txt +++ b/mindspore/ccsrc/CMakeLists.txt @@ -309,7 +309,7 @@ if(ENABLE_D) target_link_options(ms_profile PRIVATE -Wl,-init,common_log_init) target_link_libraries(ms_profile -Wl,--start-group -Wl,--whole-archive ${PROFILING} -Wl,--no-whole-archive mindspore::protobuf -Wl,--end-group) - target_link_libraries(mindspore ge_runtime ${CCE_LIB} ${RUNTIME_LIB} ${TSDCLIENT} ${HCCL} ${DATATRANSFER} + target_link_libraries(mindspore ${CCE_LIB} ${RUNTIME_LIB} ${TSDCLIENT} ${HCCL} ${DATATRANSFER} ${HCCL_ADPTER} ${REGISTER} -Wl,--no-as-needed ${OPTILING} ${HCCL_BUILDER} ${HCCL_RA} ${PLATFORM} ${ACL}) target_link_libraries(mindspore -Wl,--start-group proto_input mindspore::protobuf -Wl,--end-group) diff --git a/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_mod.cc b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_mod.cc index 8b5f8aecaf8..2a753bbc62b 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_mod.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_mod.cc @@ -30,7 +30,7 @@ #include "runtime/device/kernel_runtime.h" #include "runtime/device/ascend/executor/host_dynamic_kernel.h" -using AicpuTaskInfoPtr = std::shared_ptr; +using AicpuTaskInfoPtr = std::shared_ptr; using AicpuDynamicKernel = mindspore::device::ascend::AiCpuDynamicKernel; using HostDynamicKernel = mindspore::device::ascend::HostDynamicKernel; @@ -193,9 +193,9 @@ std::vector AicpuOpKernelMod::GenTask(const std::vector node_name_ = kPack; } - AicpuTaskInfoPtr task_info_ptr = - make_shared(kernel_name_, stream_id, node_so_, node_name_, node_def_str_, - ext_info_, input_data_addrs, output_data_addrs, NeedDump()); + AicpuTaskInfoPtr task_info_ptr = std::make_shared( + kernel_name_, stream_id, node_so_, node_name_, node_def_str_, ext_info_, input_data_addrs, output_data_addrs, + NeedDump()); MS_LOG(INFO) << "AicpuOpKernelMod GenTask end"; return {task_info_ptr}; diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/ascend/akg_ascend_kernel_mod.cc b/mindspore/ccsrc/backend/kernel_compiler/akg/ascend/akg_ascend_kernel_mod.cc index 8fea3eefd65..d39dc938eec 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/akg/ascend/akg_ascend_kernel_mod.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/akg/ascend/akg_ascend_kernel_mod.cc @@ -29,7 +29,7 @@ using std::fstream; using std::map; using std::mutex; using std::string; -using TbeTaskInfoPtr = std::shared_ptr; +using TbeTaskInfoPtr = std::shared_ptr; using tbe::KernelManager; constexpr uint32_t DEFAULT_BLOCK_DIM = 1; /** @@ -118,7 +118,7 @@ std::vector AkgKernelMod::GenTask(const std::vector &in MS_LOG(DEBUG) << "The block_dim is:" << block_dim; - TbeTaskInfoPtr task_info_ptr = make_shared( + TbeTaskInfoPtr task_info_ptr = std::make_shared( kernel_name_, stream_id, stub_func, block_dim, args, args_size, sm_desc, binary, binary_size, meta_data, input_data_addrs, output_data_addrs, workspace_addrs, NeedDump()); return {task_info_ptr}; diff --git a/mindspore/ccsrc/backend/kernel_compiler/ascend_kernel_mod.h b/mindspore/ccsrc/backend/kernel_compiler/ascend_kernel_mod.h index 47cdd7dfff5..231539d7871 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/ascend_kernel_mod.h +++ b/mindspore/ccsrc/backend/kernel_compiler/ascend_kernel_mod.h @@ -19,11 +19,11 @@ #include #include -#include "framework/ge_runtime/task_info.h" +#include "runtime/device/ascend/ge_runtime/task_info.h" 
#include "backend/kernel_compiler/kernel.h" #include "debug/data_dump/dump_json_parser.h" -using TaskInfoPtr = std::shared_ptr; +using TaskInfoPtr = std::shared_ptr; namespace mindspore { namespace kernel { class AscendKernelMod : public KernelMod { diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.cc index 36c36ce4abc..89abbeb043f 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.cc @@ -24,8 +24,8 @@ #include "runtime/device/ascend/executor/hccl_dynamic_kernel.h" #include "runtime/hccl_adapter/hccl_adapter.h" -using HcclTaskInfoPtr = std::shared_ptr; -using ge::model_runner::HcclTaskInfo; +using HcclTaskInfoPtr = std::shared_ptr; +using mindspore::ge::model_runner::HcclTaskInfo; namespace { static std::map kMsOpNameToHcomHcclType = { diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/assign.cc b/mindspore/ccsrc/backend/kernel_compiler/rts/assign.cc index 7e98fb5994f..d4b12350463 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/rts/assign.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/rts/assign.cc @@ -18,7 +18,7 @@ #include #include "runtime/mem.h" -using ge::model_runner::MemcpyAsyncTaskInfo; +using mindspore::ge::model_runner::MemcpyAsyncTaskInfo; using MemcpyAsyncTaskInfoPtr = std::shared_ptr; namespace mindspore { diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/label_goto.cc b/mindspore/ccsrc/backend/kernel_compiler/rts/label_goto.cc index eb651b34667..82964cd4e31 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/rts/label_goto.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/rts/label_goto.cc @@ -20,7 +20,7 @@ #include "framework/ge_runtime/task_info.h" #include "backend/session/anf_runtime_algorithm.h" -using ge::model_runner::LabelGotoTaskInfo; +using mindspore::ge::model_runner::LabelGotoTaskInfo; using LabelGotoTaskInfoPtr = std::shared_ptr; namespace mindspore { diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/label_set.cc b/mindspore/ccsrc/backend/kernel_compiler/rts/label_set.cc index 69f0bfd49b4..8dc947c0182 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/rts/label_set.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/rts/label_set.cc @@ -20,7 +20,7 @@ #include "framework/ge_runtime/task_info.h" #include "backend/session/anf_runtime_algorithm.h" -using ge::model_runner::LabelSetTaskInfo; +using mindspore::ge::model_runner::LabelSetTaskInfo; using LabelSetTaskInfoPtr = std::shared_ptr; namespace mindspore { diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/label_switch.cc b/mindspore/ccsrc/backend/kernel_compiler/rts/label_switch.cc index 5f2be915d79..29df657d2e6 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/rts/label_switch.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/rts/label_switch.cc @@ -21,7 +21,7 @@ #include "framework/ge_runtime/task_info.h" #include "backend/session/anf_runtime_algorithm.h" -using ge::model_runner::LabelSwitchTaskInfo; +using mindspore::ge::model_runner::LabelSwitchTaskInfo; using LabelSwitchTaskInfoPtr = std::shared_ptr; namespace mindspore { diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/memcpy_async.cc b/mindspore/ccsrc/backend/kernel_compiler/rts/memcpy_async.cc index f05c302f93c..a5914498077 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/rts/memcpy_async.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/rts/memcpy_async.cc @@ -25,7 +25,7 @@ #include "runtime/device/kernel_runtime.h" #include 
"runtime/device/ascend/executor/rts/memcpy_rts_dynamic_kernel.h" -using ge::model_runner::MemcpyAsyncTaskInfo; +using mindspore::ge::model_runner::MemcpyAsyncTaskInfo; using MemcpyAsyncTaskInfoPtr = std::shared_ptr; using AddressPtrList = std::vector; using mindspore::device::ascend::MemcpyRtsDynamicKernel; diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/profiling_kernel_mod.cc b/mindspore/ccsrc/backend/kernel_compiler/rts/profiling_kernel_mod.cc index cbbfba380c5..ad8e1beeeb1 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/rts/profiling_kernel_mod.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/rts/profiling_kernel_mod.cc @@ -23,7 +23,7 @@ #include "backend/session/anf_runtime_algorithm.h" #include "runtime/device/ascend/executor/rts/profiling_rts_dynamic_kernel.h" -using ProfilerTraceTaskInfo = ge::model_runner::ProfilerTraceTaskInfo; +using ProfilerTraceTaskInfo = mindspore::ge::model_runner::ProfilerTraceTaskInfo; using mindspore::device::ascend::ProfilingRtsDynamicKernel; using mindspore::device::ascend::ProfilingUtils; diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/recv.cc b/mindspore/ccsrc/backend/kernel_compiler/rts/recv.cc index 1661ecf256d..6a2ca1f9b96 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/rts/recv.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/rts/recv.cc @@ -23,7 +23,7 @@ namespace mindspore { namespace kernel { -using ge::model_runner::EventWaitTaskInfo; +using mindspore::ge::model_runner::EventWaitTaskInfo; using EventWaitTaskInfoPtr = std::shared_ptr; RecvKernel::RecvKernel() { event_id_ = 0; } diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/send.cc b/mindspore/ccsrc/backend/kernel_compiler/rts/send.cc index 53081f47918..ef5e678f3c8 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/rts/send.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/rts/send.cc @@ -20,7 +20,7 @@ #include "framework/ge_runtime/task_info.h" #include "backend/session/anf_runtime_algorithm.h" -using ge::model_runner::EventRecordTaskInfo; +using mindspore::ge::model_runner::EventRecordTaskInfo; using EventRecordTaskInfoPtr = std::shared_ptr; namespace mindspore { diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/stream_active.cc b/mindspore/ccsrc/backend/kernel_compiler/rts/stream_active.cc index 77b80346d15..f59238a07d0 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/rts/stream_active.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/rts/stream_active.cc @@ -20,7 +20,7 @@ #include "framework/ge_runtime/task_info.h" #include "backend/session/anf_runtime_algorithm.h" -using ge::model_runner::StreamActiveTaskInfo; +using mindspore::ge::model_runner::StreamActiveTaskInfo; using StreamActiveTaskInfoPtr = std::shared_ptr; namespace mindspore { diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/stream_switch.cc b/mindspore/ccsrc/backend/kernel_compiler/rts/stream_switch.cc index e5947487546..e85b6bb97d0 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/rts/stream_switch.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/rts/stream_switch.cc @@ -21,7 +21,7 @@ #include "framework/ge_runtime/task_info.h" #include "backend/session/anf_runtime_algorithm.h" -using ge::model_runner::StreamSwitchTaskInfo; +using mindspore::ge::model_runner::StreamSwitchTaskInfo; using StreamSwitchTaskInfoPtr = std::shared_ptr; namespace mindspore { diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_mod.cc b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_mod.cc index bde79f0b400..a49076e1b8b 100644 --- 
a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_mod.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_mod.cc @@ -24,7 +24,7 @@ namespace mindspore { namespace kernel { -using TbeTaskInfoPtr = std::shared_ptr; +using TbeTaskInfoPtr = std::shared_ptr; using tbe::KernelManager; using AddressPtrList = std::vector; bool TbeKernelMod::Launch(const std::vector &inputs, @@ -102,7 +102,7 @@ std::vector TbeKernelMod::GenTask(const std::vector &in MS_LOG(INFO) << "block_dim is:" << block_dim_; - TbeTaskInfoPtr task_info_ptr = make_shared( + TbeTaskInfoPtr task_info_ptr = std::make_shared( kernel_name_, stream_id, stub_func, block_dim_, args, 0, sm_desc, nullptr, 0, meta_data, input_data_addrs, output_data_addrs, workspace_addrs, NeedDump()); return {task_info_ptr}; diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_parallel_build.cc b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_parallel_build.cc index 6c5d6616f24..4afa9d9ffb8 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_parallel_build.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_parallel_build.cc @@ -36,7 +36,7 @@ using mindspore::kernel::tbe::TbeUtils; bool TbeOpParallelBuild(const std::vector &anf_nodes) { auto build_manger = std::make_shared(); MS_EXCEPTION_IF_NULL(build_manger); - static set processed_kernel; + static std::set processed_kernel; auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); auto tune_mode = context_ptr->get_param(MS_CTX_TUNE_MODE); @@ -259,8 +259,8 @@ bool ParallelBuildManager::SearchInCache(const std::string &json_name, const std } KernelModPtr ParallelBuildManager::GenKernelMod(const string &json_name, const string &processor, - const vector &input_size_list, - const vector &output_size_list, + const std::vector &input_size_list, + const std::vector &output_size_list, const mindspore::kernel::KernelPackPtr &kernel_pack) const { MS_EXCEPTION_IF_NULL(kernel_pack); auto kernel_json_info = kernel_pack->kernel_json_info(); diff --git a/mindspore/ccsrc/ps/ps_cache/ascend/ascend_ps_cache.cc b/mindspore/ccsrc/ps/ps_cache/ascend/ascend_ps_cache.cc index 463787d00d1..5dfd80e80e2 100644 --- a/mindspore/ccsrc/ps/ps_cache/ascend/ascend_ps_cache.cc +++ b/mindspore/ccsrc/ps/ps_cache/ascend/ascend_ps_cache.cc @@ -27,6 +27,7 @@ #include "proto/tensor_shape.pb.h" #include "proto/attr.pb.h" #include "proto/node_def.pb.h" +#include "runtime/rt.h" using mindspore::kernel::Address; using AddressPtr = std::shared_ptr
; diff --git a/mindspore/ccsrc/ps/ps_cache/ascend/ascend_ps_cache.h b/mindspore/ccsrc/ps/ps_cache/ascend/ascend_ps_cache.h index 966da84b00e..8470f6f78ab 100644 --- a/mindspore/ccsrc/ps/ps_cache/ascend/ascend_ps_cache.h +++ b/mindspore/ccsrc/ps/ps_cache/ascend/ascend_ps_cache.h @@ -24,6 +24,7 @@ #include "ps/ps_cache/ps_cache_basic.h" #include "backend/kernel_compiler/aicpu/aicpu_kernel_mod.h" #include "ir/dtype.h" +#include "runtime/base.h" namespace mindspore { namespace ps { diff --git a/mindspore/ccsrc/runtime/device/CMakeLists.txt b/mindspore/ccsrc/runtime/device/CMakeLists.txt index c7d87544357..68534970005 100644 --- a/mindspore/ccsrc/runtime/device/CMakeLists.txt +++ b/mindspore/ccsrc/runtime/device/CMakeLists.txt @@ -79,3 +79,9 @@ list(REMOVE_ITEM D_SRC_LIST "ascend/profiling/profiling_callback_register.cc") set_property(SOURCE ${DEVICE_SRC_LIST} ${D_SRC_LIST} ${CPU_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE) add_library(_mindspore_runtime_device_obj OBJECT ${DEVICE_SRC_LIST} ${D_SRC_LIST} ${CPU_SRC_LIST} ${TDT_SRC_LIST}) +if(ENABLE_D) + file(GLOB_RECURSE GE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "ascend/ge_runtime/*.cc") + set_property(SOURCE ${GE_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_GE) + target_include_directories(_mindspore_runtime_device_obj PRIVATE ${CMAKE_BINARY_DIR}/proto/ge) + add_dependencies(_mindspore_runtime_device_obj graph) +endif() diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc index 7fe3d4e60a3..be0308a379c 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc @@ -28,9 +28,9 @@ #include "utils/mpi/mpi_config.h" #include "runtime/device/ascend/profiling/profiling_manager.h" #include "common/trans.h" -#include "runtime/context.h" +#include "runtime/rt.h" #include "runtime/device/ascend/ascend_stream_assign.h" -#include "framework/ge_runtime/model_runner.h" +#include "runtime/device/ascend/ge_runtime/model_runner.h" #include "runtime/device/ascend/tasksink/task_generator.h" #include "backend/session/anf_runtime_algorithm.h" #include "runtime/device/ascend/profiling/profiling_utils.h" @@ -40,7 +40,6 @@ #include "toolchain/adx_datadump_server.h" #include "utils/trace_base.h" #include "graphengine/inc/external/acl/error_codes/rt_error_codes.h" -#include "utils/runtime_error_codes.h" #include "debug/anf_ir_dump.h" #ifdef MEM_REUSE_DEBUG #include "backend/optimizer/mem_reuse/mem_reuse_checker.h" @@ -61,10 +60,10 @@ using mindspore::dataset::TdtHandle; #include "debug/rdr/running_data_recorder.h" #endif -using ge::model_runner::ModelRunner; using mindspore::device::ascend::ProfilingManager; using mindspore::device::ascend::ProfilingUtils; using mindspore::device::ascend::tasksink::TaskGenerator; +using mindspore::ge::model_runner::ModelRunner; using mindspore::kernel::tbe::TbeUtils; using std::vector; @@ -158,10 +157,7 @@ void AscendKernelRuntime::ClearGraphModelMap() { graph_kernel_events_map_.clear(); for (auto &iter : graph_model_map_) { MS_LOG(INFO) << "Ge UnloadModel " << iter.first; - auto ret = ModelRunner::Instance().UnloadModel(iter.first); - if (!ret) { - MS_LOG(ERROR) << "UnloadModel failed"; - } + ModelRunner::Instance().UnloadModel(iter.first); } } @@ -194,10 +190,7 @@ void AscendKernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id, const std MS_LOG(DEBUG) << "Clear graph:" << 
graph_id << " runtime resource"; if (auto model_iter = graph_model_map_.find(graph_id); model_iter != graph_model_map_.end()) { MS_LOG(DEBUG) << "Ge UnloadModel " << graph_id; - auto ret = ModelRunner::Instance().UnloadModel(graph_id); - if (!ret) { - MS_LOG(ERROR) << "UnloadModel failed"; - } + ModelRunner::Instance().UnloadModel(graph_id); graph_model_map_.erase(model_iter); } else { MS_LOG(DEBUG) << "GraphId:" << graph_id << " not found"; @@ -482,10 +475,9 @@ bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) { << ", total label num:" << graph->label_num() << ", wait_active_stream_list size:" << wait_active_stream_list.size() << ", force_copy_stream_list size:" << force_copy_stream_list.size(); - std::vector> empty_list; auto model = std::make_shared( - task_info_list, empty_list, empty_list, empty_list, empty_list, wait_active_stream_list, force_copy_stream_list, 0, - 0, 0, 0, 0, 0, resource_manager.get_cur_stream_num(), graph->label_num(), resource_manager.get_cur_event_num(), 0); + task_info_list, wait_active_stream_list, force_copy_stream_list, 0, 0, 0, 0, 0, 0, + resource_manager.get_cur_stream_num(), graph->label_num(), resource_manager.get_cur_event_num(), 0); auto ret = graph_model_map_.insert(std::make_pair(graph->graph_id(), model)); if (!ret.second) { MS_LOG(EXCEPTION) << "Duplicate GraphId! Please check in ascend_session."; @@ -514,24 +506,20 @@ bool AscendKernelRuntime::LoadTask(const session::KernelGraph *graph) { return false; } - std::shared_ptr listener; MS_LOG(INFO) << "LoadDavinciModel mode_id:" << model_iter->first; - bool status = - ModelRunner::Instance().LoadDavinciModel(device_id_, 0, model_iter->first, model_iter->second, listener); - if (!status) { - MS_LOG(EXCEPTION) << "Load Model Failed"; - } + ModelRunner::Instance().LoadDavinciModel(device_id_, 0, model_iter->first, model_iter->second); std::function model_handle = std::bind(&ModelRunner::GetModelHandle, &ModelRunner::Instance(), model_iter->first); DistributeDebugTask(NOT_NULL(graph), NOT_NULL(model_handle)); - status = ModelRunner::Instance().DistributeTask(model_iter->first); - if (!status) { + try { + ModelRunner::Instance().DistributeTask(model_iter->first); + } catch (const std::exception &e) { #ifdef ENABLE_DUMP_IR mindspore::RDR::TriggerAll(); #endif - MS_LOG(EXCEPTION) << "Distribute Task Failed"; + MS_LOG(EXCEPTION) << "Distribute Task Failed, error: " << e.what(); } if (ProfilingManager::GetInstance().IsProfiling()) { @@ -542,10 +530,7 @@ bool AscendKernelRuntime::LoadTask(const session::KernelGraph *graph) { LaunchDataDump(graph->graph_id()); - if (!ModelRunner::Instance().LoadModelComplete(model_iter->first)) { - MS_LOG(ERROR) << "Call ge runtime LoadModelComplete failed"; - return false; - } + ModelRunner::Instance().LoadModelComplete(model_iter->first); return true; } @@ -730,8 +715,6 @@ bool AscendKernelRuntime::RunTask(const session::KernelGraph *graph) { auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); - ge::InputData input_tensors = ge::InputData(); - ge::OutputData *output_tensors = nullptr; if (GraphWithEmptyTaskList(graph)) { MS_LOG(WARNING) << "RunTask end, no task info found"; return true; @@ -742,8 +725,9 @@ bool AscendKernelRuntime::RunTask(const session::KernelGraph *graph) { return false; } - bool status = ModelRunner::Instance().RunModel(graph->graph_id(), input_tensors, output_tensors); - if (!status) { + try { + ModelRunner::Instance().RunModel(graph->graph_id()); + } catch (const std::exception &) { 
DumpTaskExceptionInfo(graph); std::string file_name = "task_error_debug" + std::to_string(graph->graph_id()) + ".ir"; auto graph_tmp = std::make_shared(*graph); @@ -988,7 +972,7 @@ void AscendKernelRuntime::KernelLaunchProfiling(const std::string &kernel_name) } uint64_t AscendKernelRuntime::GetAvailableMemMaxSize() const { - auto ascend_mem_manager = dynamic_pointer_cast(mem_manager_); + auto ascend_mem_manager = std::dynamic_pointer_cast(mem_manager_); return ascend_mem_manager->GetDeviceMemSize(); } diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h index cec00ae9f23..9d4d1bec66c 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h @@ -25,15 +25,15 @@ #include #include "runtime/device/kernel_runtime.h" #include "runtime/context.h" -#include "framework/ge_runtime/davinci_model.h" +#include "runtime/device/ascend/ge_runtime/davinci_model.h" #include "runtime/device/kernel_runtime_manager.h" #include "backend/session/session_basic.h" #include "runtime/device/ascend/dump/data_dumper.h" -using ge::model_runner::TaskInfo; using std::unordered_map; using std::vector; namespace mindspore::device::ascend { +using ge::model_runner::TaskInfo; class AscendKernelRuntime : public KernelRuntime { public: AscendKernelRuntime() = default; diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_memory_pool.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_memory_pool.cc index 861e45f1583..1a7fc59f46a 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_memory_pool.cc +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_memory_pool.cc @@ -16,6 +16,7 @@ #include #include "runtime/device/ascend/ascend_memory_pool.h" +#include "runtime/mem.h" #include "runtime/device/ascend/ascend_kernel_runtime.h" #include "utils/log_adapter.h" diff --git a/mindspore/ccsrc/runtime/device/ascend/ge_runtime/davinci_model.h b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/davinci_model.h new file mode 100644 index 00000000000..8808ce38bee --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/davinci_model.h @@ -0,0 +1,92 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_DAVINCI_MODEL_H_
+#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_DAVINCI_MODEL_H_
+
+#include <memory>
+#include <vector>
+#include "runtime/device/ascend/ge_runtime/task_info.h"
+
+namespace mindspore::ge::model_runner {
+class DavinciModel {
+ public:
+  DavinciModel(const std::vector<std::shared_ptr<TaskInfo>> &task_info_list,
+               const std::vector<uint32_t> &wait_active_stream_list,
+               const std::vector<uint32_t> &force_copy_stream_list, uint64_t mem_size = 0, uint64_t weight_size = 0,
+               uint64_t var_size = 0, uintptr_t logic_mem_base = 0, uintptr_t logic_weight_base = 0,
+               uintptr_t logic_var_base = 0, uint32_t stream_num = 0, uint32_t batch_num = 0, uint32_t event_num = 0,
+               int32_t priority = 0)
+      : task_info_list_(task_info_list),
+        wait_active_stream_list_(wait_active_stream_list),
+        force_copy_stream_list_(force_copy_stream_list),
+        mem_size_(mem_size),
+        weight_size_(weight_size),
+        var_size_(var_size),
+        logic_mem_base_(logic_mem_base),
+        logic_weight_base_(logic_weight_base),
+        logic_var_base_(logic_var_base),
+        stream_num_(stream_num),
+        batch_num_(batch_num),
+        event_num_(event_num),
+        priority_(priority) {}
+  ~DavinciModel() {}
+
+  uint64_t GetMemSize() const { return mem_size_; }
+  uint64_t GetWeightSize() const { return weight_size_; }
+  uint64_t GetVarSize() const { return var_size_; }
+
+  uintptr_t GetLogicMemBase() const { return logic_mem_base_; }
+  uintptr_t GetLogicWeightBase() const { return logic_weight_base_; }
+  uintptr_t GetLogicVarBase() const { return logic_var_base_; }
+
+  uint32_t GetStreamNum() const { return stream_num_; }
+  uint32_t GetBatchNum() const { return batch_num_; }
+  uint32_t GetEventNum() const { return event_num_; }
+
+  const std::vector<uint32_t> &GetWaitActiveStreams() const { return wait_active_stream_list_; }
+  const std::vector<uint32_t> &GetForceCopyStreams() const { return force_copy_stream_list_; }
+
+  int32_t GetPriority() const { return priority_; }
+
+  const std::vector<std::shared_ptr<TaskInfo>> &GetTaskInfoList() const { return task_info_list_; }
+
+ private:
+  std::vector<std::shared_ptr<TaskInfo>> task_info_list_;
+
+  std::vector<uint32_t> wait_active_stream_list_;
+  std::vector<uint32_t> force_copy_stream_list_;
+
+  uint64_t mem_size_;
+  uint64_t weight_size_;
+  uint64_t var_size_;
+
+  uintptr_t logic_mem_base_;
+  uintptr_t logic_weight_base_;
+  uintptr_t logic_var_base_;
+
+  uint32_t stream_num_;
+  uint32_t batch_num_;
+  uint32_t event_num_;
+
+  int32_t priority_;
+
+  // Disallow copy constructor and assignment operator
+  DavinciModel &operator=(const DavinciModel &) = delete;
+  DavinciModel(const DavinciModel &) = delete;
+};
+}  // namespace mindspore::ge::model_runner
+#endif  // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_DAVINCI_MODEL_H_
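For orientation, a minimal sketch (editor's illustration, not part of the patch) of how a DavinciModel is assembled from generated task infos. It mirrors the call site in AscendKernelRuntime::GenTask above; note that GenTask feeds graph->label_num() into the batch_num slot, which InitLabel later reads back as the label count. The helper name BuildModel and its parameters are hypothetical:

    // Hypothetical illustration: building a DavinciModel from task infos,
    // matching the simplified constructor this patch introduces.
    #include <memory>
    #include <vector>
    #include "runtime/device/ascend/ge_runtime/davinci_model.h"

    using mindspore::ge::model_runner::DavinciModel;
    using mindspore::ge::model_runner::TaskInfo;

    std::shared_ptr<DavinciModel> BuildModel(const std::vector<std::shared_ptr<TaskInfo>> &tasks,
                                             uint32_t stream_num, uint32_t label_num, uint32_t event_num) {
      std::vector<uint32_t> wait_active_streams;  // streams not bound as head streams
      std::vector<uint32_t> force_copy_streams;   // streams created with RT_STREAM_FORCE_COPY
      // Mirrors GenTask: batch_num carries the graph's label count.
      return std::make_shared<DavinciModel>(tasks, wait_active_streams, force_copy_streams, 0, 0, 0, 0, 0, 0,
                                            stream_num, label_num, event_num, 0);
    }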
diff --git a/mindspore/ccsrc/runtime/device/ascend/ge_runtime/model_context.h b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/model_context.h
new file mode 100644
index 00000000000..d9e148eec0c
--- /dev/null
+++ b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/model_context.h
@@ -0,0 +1,59 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_MODEL_CONTEXT_H_
+#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_MODEL_CONTEXT_H_
+
+#include <vector>
+#include "runtime/rt_model.h"
+
+namespace mindspore::ge::model_runner {
+class ModelContext {
+ public:
+  ModelContext(uint32_t device_id, uint64_t session_id, int32_t priority, rtModel_t rt_model_handle,
+               rtStream_t rt_model_stream, const std::vector<rtStream_t> &stream_list,
+               const std::vector<rtLabel_t> &label_list, const std::vector<rtEvent_t> &event_list)
+      : device_id_(device_id),
+        session_id_(session_id),
+        priority_(priority),
+        rt_model_handle_(rt_model_handle),
+        rt_model_stream_(rt_model_stream),
+        stream_list_(stream_list),
+        label_list_(label_list),
+        event_list_(event_list) {}
+  ~ModelContext() {}
+
+  uint32_t device_id() const { return device_id_; }
+  uint64_t session_id() const { return session_id_; }
+  int32_t priority() const { return priority_; }
+  const rtModel_t &rt_model_handle() const { return rt_model_handle_; }
+  const rtStream_t &rt_model_stream() const { return rt_model_stream_; }
+  const std::vector<rtStream_t> &stream_list() const { return stream_list_; }
+  const std::vector<rtLabel_t> &label_list() const { return label_list_; }
+  const std::vector<rtEvent_t> &event_list() const { return event_list_; }
+
+ private:
+  uint32_t device_id_;
+  uint64_t session_id_;
+  int32_t priority_;
+  rtModel_t rt_model_handle_;
+  rtStream_t rt_model_stream_;
+  std::vector<rtStream_t> stream_list_;
+  std::vector<rtLabel_t> label_list_;
+  std::vector<rtEvent_t> event_list_;
+};
+}  // namespace mindspore::ge::model_runner
+#endif  // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_MODEL_CONTEXT_H_
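ModelContext is the read-only view each task gets of the model's runtime resources. A condensed sketch (editor's illustration, not part of the patch) of the stream-selection pattern the concrete tasks use, mirroring the AicpuTask constructor further down; the real tasks raise MS_LOG(EXCEPTION) instead of returning nullptr:

    // Sketch only: how a task resolves its rtStream_t from ModelContext.
    // A single-stream model maps every task to stream_list[0]; otherwise the
    // task's stream_id indexes into the context's stream list.
    rtStream_t ResolveStream(const mindspore::ge::model_runner::ModelContext &ctx, uint32_t stream_id) {
      const auto &streams = ctx.stream_list();
      if (streams.size() == 1) return streams[0];
      if (stream_id < streams.size()) return streams[stream_id];
      return nullptr;  // the in-tree tasks throw via MS_LOG(EXCEPTION) here
    }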
diff --git a/mindspore/ccsrc/runtime/device/ascend/ge_runtime/model_runner.cc b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/model_runner.cc
new file mode 100644
index 00000000000..1cf4c858002
--- /dev/null
+++ b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/model_runner.cc
@@ -0,0 +1,104 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "runtime/device/ascend/ge_runtime/model_runner.h"
+#include "runtime/device/ascend/ge_runtime/runtime_model.h"
+#include "runtime/device/ascend/ge_runtime/davinci_model.h"
+#include "mindspore/core/utils/log_adapter.h"
+
+namespace mindspore::ge::model_runner {
+ModelRunner &ModelRunner::Instance() {
+  static ModelRunner instance;  // Guaranteed to be destroyed.
+  return instance;
+}
+
+void ModelRunner::LoadDavinciModel(uint32_t device_id, uint64_t session_id, uint32_t model_id,
+                                   const std::shared_ptr<DavinciModel> &davinci_model) {
+  std::shared_ptr<RuntimeModel> model = std::make_shared<RuntimeModel>();
+  model->Load(device_id, session_id, davinci_model);
+  runtime_models_[model_id] = model;
+}
+
+void ModelRunner::DistributeTask(uint32_t model_id) {
+  auto model_iter = runtime_models_.find(model_id);
+  if (model_iter == runtime_models_.end()) {
+    MS_LOG(EXCEPTION) << "Model id " << model_id << " not found.";
+  }
+  MS_EXCEPTION_IF_NULL(model_iter->second);
+  model_iter->second->DistributeTask();
+}
+
+void ModelRunner::LoadModelComplete(uint32_t model_id) {
+  auto model_iter = runtime_models_.find(model_id);
+  if (model_iter == runtime_models_.end()) {
+    MS_LOG(EXCEPTION) << "Model id " << model_id << " not found.";
+  }
+  MS_EXCEPTION_IF_NULL(model_iter->second);
+  model_iter->second->LoadComplete();
+}
+
+const std::vector<uint32_t> &ModelRunner::GetTaskIdList(uint32_t model_id) const {
+  auto model_iter = runtime_models_.find(model_id);
+  if (model_iter == runtime_models_.end()) {
+    MS_LOG(EXCEPTION) << "Model id " << model_id << " not found.";
+  }
+  MS_EXCEPTION_IF_NULL(model_iter->second);
+  return model_iter->second->GetTaskIdList();
+}
+
+const std::vector<uint32_t> &ModelRunner::GetStreamIdList(uint32_t model_id) const {
+  auto model_iter = runtime_models_.find(model_id);
+  if (model_iter == runtime_models_.end()) {
+    MS_LOG(EXCEPTION) << "Model id " << model_id << " not found.";
+  }
+  MS_EXCEPTION_IF_NULL(model_iter->second);
+  return model_iter->second->GetStreamIdList();
+}
+
+const std::map<std::string, std::shared_ptr<RuntimeInfo>> &ModelRunner::GetRuntimeInfoMap(uint32_t model_id) const {
+  auto model_iter = runtime_models_.find(model_id);
+  if (model_iter == runtime_models_.end()) {
+    MS_LOG(EXCEPTION) << "Model id " << model_id << " not found.";
+  }
+  MS_EXCEPTION_IF_NULL(model_iter->second);
+  return model_iter->second->GetRuntimeInfoMap();
+}
+
+void *ModelRunner::GetModelHandle(uint32_t model_id) const {
+  auto model_iter = runtime_models_.find(model_id);
+  if (model_iter == runtime_models_.end()) {
+    MS_LOG(EXCEPTION) << "Model id " << model_id << " not found.";
+  }
+  MS_EXCEPTION_IF_NULL(model_iter->second);
+  return model_iter->second->GetModelHandle();
+}
+
+void ModelRunner::UnloadModel(uint32_t model_id) {
+  auto iter = runtime_models_.find(model_id);
+  if (iter != runtime_models_.end()) {
+    (void)runtime_models_.erase(iter);
+  }
+}
+
+void ModelRunner::RunModel(uint32_t model_id) {
+  auto model_iter = runtime_models_.find(model_id);
+  if (model_iter == runtime_models_.end()) {
+    MS_LOG(EXCEPTION) << "Model id " << model_id << " not found.";
+  }
+  MS_EXCEPTION_IF_NULL(model_iter->second);
+  model_iter->second->Run();
+}
+}  // namespace mindspore::ge::model_runner
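Taken together with the AscendKernelRuntime changes above, the intended call sequence is Load → DistributeTask → LoadModelComplete → RunModel → UnloadModel, with failures surfacing as exceptions rather than the old bool returns. A condensed caller-side sketch (editor's illustration; LoadAndRun and its arguments are stand-ins):

    // Sketch of the caller-side flow after this patch (exceptions replace the
    // bool-returning API of the external ge_runtime).
    using mindspore::ge::model_runner::DavinciModel;
    using mindspore::ge::model_runner::ModelRunner;

    void LoadAndRun(uint32_t device_id, uint32_t graph_id, const std::shared_ptr<DavinciModel> &model) {
      auto &runner = ModelRunner::Instance();
      runner.LoadDavinciModel(device_id, 0, graph_id, model);  // no listener parameter any more
      try {
        runner.DistributeTask(graph_id);    // throws on failure
        runner.LoadModelComplete(graph_id);
        runner.RunModel(graph_id);          // rtModelExecute + rtStreamSynchronize
      } catch (const std::exception &) {
        // callers such as AscendKernelRuntime::LoadTask dump debug info here
        throw;
      }
      runner.UnloadModel(graph_id);
    }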
diff --git a/mindspore/ccsrc/runtime/device/ascend/ge_runtime/model_runner.h b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/model_runner.h
new file mode 100644
index 00000000000..b11c85d0bd9
--- /dev/null
+++ b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/model_runner.h
@@ -0,0 +1,60 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_MODEL_RUNNER_H_
+#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_MODEL_RUNNER_H_
+
+#include <map>
+#include <memory>
+#include <string>
+#include <tuple>
+#include <vector>
+#include "runtime/device/ascend/ge_runtime/davinci_model.h"
+
+namespace mindspore::ge::model_runner {
+class RuntimeModel;
+using RuntimeInfo = std::tuple<uint32_t, uint32_t, void *>;
+class ModelRunner {
+ public:
+  static ModelRunner &Instance();
+
+  void LoadDavinciModel(uint32_t device_id, uint64_t session_id, uint32_t model_id,
+                        const std::shared_ptr<DavinciModel> &davinci_model);
+
+  void DistributeTask(uint32_t model_id);
+
+  void LoadModelComplete(uint32_t model_id);
+
+  const std::vector<uint32_t> &GetTaskIdList(uint32_t model_id) const;
+
+  const std::vector<uint32_t> &GetStreamIdList(uint32_t model_id) const;
+
+  const std::map<std::string, std::shared_ptr<RuntimeInfo>> &GetRuntimeInfoMap(uint32_t model_id) const;
+
+  void *GetModelHandle(uint32_t model_id) const;
+
+  void UnloadModel(uint32_t model_id);
+
+  void RunModel(uint32_t model_id);
+
+ private:
+  ModelRunner() = default;
+  ~ModelRunner() = default;
+
+  std::map<uint32_t, std::shared_ptr<RuntimeModel>> runtime_models_;
+};
+}  // namespace mindspore::ge::model_runner
+#endif  // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_MODEL_RUNNER_H_
diff --git a/mindspore/ccsrc/runtime/device/ascend/ge_runtime/runtime_model.cc b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/runtime_model.cc
new file mode 100644
index 00000000000..e7533de5881
--- /dev/null
+++ b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/runtime_model.cc
@@ -0,0 +1,292 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "runtime/device/ascend/ge_runtime/runtime_model.h"
+#include <set>
+#include "runtime/kernel.h"
+#include "runtime/rt_model.h"
+#include "graphengine/inc/external/runtime/rt_error_codes.h"
+#include "runtime/device/ascend/ge_runtime/model_context.h"
+#include "runtime/device/ascend/ge_runtime/task/task.h"
+#include "runtime/device/ascend/ge_runtime/task/task_factory.h"
+#include "mindspore/core/utils/log_adapter.h"
+
+namespace mindspore::ge::model_runner {
+RuntimeModel::~RuntimeModel() {
+  MS_LOG(INFO) << "RuntimeModel destructor start.";
+
+  // Unbind rtModel from all task related streams
+  RtModelUnbindStream();
+
+  // Release tasks first, hccl tasks hold streams
+  task_list_.clear();
+
+  // Release all task related streams
+  RtStreamDestory();
+
+  // Release rtLabel resource
+  RtLabelDestory();
+
+  // Release rtEvent resource
+  RtEventDestory();
+
+  MS_LOG(INFO) << "Do RtModelDestroy";
+  // Release all rt_model
+  RtModelDestory();
+}
+
+void RuntimeModel::InitStream(const std::shared_ptr<DavinciModel> &davinci_model) {
+  MS_EXCEPTION_IF_NULL(davinci_model);
+
+  std::set<uint32_t> wait_active_streams;
+  std::set<uint32_t> force_copy_streams;
+
+  for (const auto &stream_id : davinci_model->GetWaitActiveStreams()) {
+    MS_LOG(INFO) << "Stream id " << stream_id << " is wait active stream.";
+    (void)wait_active_streams.insert(stream_id);
+  }
+
+  for (const auto &stream_id : davinci_model->GetForceCopyStreams()) {
+    MS_LOG(INFO) << "Stream id " << stream_id << " is force copy stream.";
+    (void)force_copy_streams.insert(stream_id);
+  }
+
+  MS_LOG(INFO) << "Total stream num " << davinci_model->GetStreamNum();
+  for (uint32_t i = 0; i < davinci_model->GetStreamNum(); ++i) {
+    rtStream_t stream = nullptr;
+    uint32_t flag = (force_copy_streams.find(i) != force_copy_streams.end())
+                      ? (RT_STREAM_PERSISTENT | RT_STREAM_FORCE_COPY)
+                      : (RT_STREAM_PERSISTENT);
+
+    rtError_t rt_ret = rtStreamCreateWithFlags(&stream, davinci_model->GetPriority(), flag);
+    if (rt_ret != RT_ERROR_NONE) {
+      MS_LOG(EXCEPTION) << "Call rt api rtStreamCreateWithFlags failed, ret: " << std::hex << rt_ret;
+    }
+
+    MS_LOG(INFO) << "rtStreamCreateWithFlags end.";
+    stream_list_.emplace_back(stream);
+
+    // Bind rt_model_handle_ to all task related streams
+    flag = (wait_active_streams.find(i) != wait_active_streams.end()) ? (static_cast<uint32_t>(RT_INVALID_FLAG))
+                                                                      : (static_cast<uint32_t>(RT_HEAD_STREAM));
+    rt_ret = rtModelBindStream(rt_model_handle_, stream, flag);
+    if (rt_ret != RT_ERROR_NONE) {
+      MS_LOG(EXCEPTION) << "Call rt api rtModelBindStream failed, ret: " << std::hex << rt_ret;
+    }
+    MS_LOG(INFO) << "stream index: " << i << ", stream: " << std::hex << stream;
+  }
+}
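The flag arithmetic in InitStream is the crux of stream setup: force-copy streams get RT_STREAM_FORCE_COPY at creation, and wait-active streams are bound without RT_HEAD_STREAM so they only run once explicitly activated by a StreamActive task. A minimal sketch of that decision (editor's illustration; the rt constants are the names visible in the patch):

    // Sketch: flag selection for stream i, as InitStream computes it.
    struct StreamFlags {
      uint32_t create_flag;  // passed to rtStreamCreateWithFlags
      uint32_t bind_flag;    // passed to rtModelBindStream
    };

    StreamFlags SelectFlags(bool force_copy, bool wait_active) {
      StreamFlags f{};
      f.create_flag = force_copy ? (RT_STREAM_PERSISTENT | RT_STREAM_FORCE_COPY) : RT_STREAM_PERSISTENT;
      f.bind_flag = wait_active ? static_cast<uint32_t>(RT_INVALID_FLAG) : static_cast<uint32_t>(RT_HEAD_STREAM);
      return f;
    }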
+
+void RuntimeModel::InitEvent(uint32_t event_num) {
+  MS_LOG(INFO) << "Event number: " << event_num;
+  for (uint32_t i = 0; i < event_num; ++i) {
+    rtEvent_t rt_event;
+    rtError_t rt_ret = rtEventCreate(&rt_event);
+    if (rt_ret != RT_ERROR_NONE) {
+      MS_LOG(EXCEPTION) << "Call rt api rtEventCreate failed, ret: " << std::hex << rt_ret;
+    }
+    event_list_.push_back(rt_event);
+  }
+}
+
+void RuntimeModel::InitLabel(const std::shared_ptr<DavinciModel> &davinci_model) {
+  MS_LOG(INFO) << "Label number: " << davinci_model->GetBatchNum();
+  label_list_.resize(davinci_model->GetBatchNum());
+  for (auto &task_info : davinci_model->GetTaskInfoList()) {
+    MS_EXCEPTION_IF_NULL(task_info);
+
+    if (task_info->type() != TaskInfoType::LABEL_SET) {
+      continue;
+    }
+    auto label_set_task_info = std::static_pointer_cast<LabelSetTaskInfo>(task_info);
+
+    if (label_set_task_info->stream_id() >= stream_list_.size()) {
+      MS_LOG(EXCEPTION) << "Invalid stream id " << label_set_task_info->stream_id() << " total stream num "
+                        << stream_list_.size();
+    }
+
+    rtLabel_t rt_label = nullptr;
+    rtError_t rt_ret = rtLabelCreateEx(&rt_label, stream_list_[label_set_task_info->stream_id()]);
+    if (rt_ret != RT_ERROR_NONE) {
+      MS_LOG(EXCEPTION) << "Call rt api rtLabelCreateEx failed, ret: " << std::hex << rt_ret;
+    }
+    label_list_[label_set_task_info->label_id()] = rt_label;
+  }
+}
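Note the plumbing here (editor's annotation): the label table is sized by GetBatchNum() and indexed by each LABEL_SET task's label_id, not by task order, so labels can be created sparsely and looked up later by goto/switch tasks.

    // Sketch of the label-count plumbing across this patch (names from the diff):
    // GenTask:   std::make_shared<DavinciModel>(..., stream_num, graph->label_num() /* batch_num */, event_num, 0);
    // InitLabel: label_list_.resize(davinci_model->GetBatchNum());  // batch_num carries label_num
    // InitLabel: label_list_[label_set_task_info->label_id()] = rt_label;  // indexed by label_id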
+
+void RuntimeModel::InitResource(const std::shared_ptr<DavinciModel> &davinci_model) {
+  MS_LOG(INFO) << "InitResource start";
+  MS_EXCEPTION_IF_NULL(davinci_model);
+
+  rtError_t rt_ret = rtModelCreate(&rt_model_handle_, 0);
+  if (rt_ret != RT_ERROR_NONE) {
+    MS_LOG(EXCEPTION) << "Call rt api rtModelCreate failed, ret: " << std::hex << rt_ret;
+  }
+
+  // Create rtStream for rt_model_handle_
+  rt_ret = rtStreamCreate(&rt_model_stream_, davinci_model->GetPriority());
+  if (rt_ret != RT_ERROR_NONE) {
+    MS_LOG(EXCEPTION) << "Call rt api rtStreamCreate failed, ret: " << std::hex << rt_ret;
+  }
+  MS_LOG(INFO) << "rtStreamCreate end";
+
+  InitStream(davinci_model);
+  InitEvent(davinci_model->GetEventNum());
+  InitLabel(davinci_model);
+
+  MS_LOG(INFO) << "InitResource success";
+}
+
+void RuntimeModel::GenerateTask(uint32_t device_id, uint64_t session_id,
+                                const std::shared_ptr<DavinciModel> &davinci_model) {
+  MS_LOG(INFO) << "GenerateTask start.";
+  MS_EXCEPTION_IF_NULL(davinci_model);
+  auto task_infos = davinci_model->GetTaskInfoList();
+  ModelContext model_context(device_id, session_id, davinci_model->GetPriority(), rt_model_handle_, rt_model_stream_,
+                             stream_list_, label_list_, event_list_);
+  for (auto &task_info : task_infos) {
+    auto task = TaskFactory::GetInstance().Create(model_context, task_info);
+    task_list_.push_back(task);
+  }
+  MS_LOG(INFO) << "GenerateTask success.";
+}
+
+void RuntimeModel::LoadComplete() {
+  uint32_t task_id = 0;
+  uint32_t stream_id = 0;
+  auto rt_ret = rtModelGetTaskId(rt_model_handle_, &task_id, &stream_id);
+  if (rt_ret != RT_ERROR_NONE) {
+    MS_LOG(EXCEPTION) << "Call rt api rtModelGetTaskId failed, ret: " << std::hex << rt_ret;
+  }
+  task_id_list_.push_back(task_id);
+  stream_id_list_.push_back(stream_id);
+
+  rt_ret = rtModelLoadComplete(rt_model_handle_);
+  if (rt_ret != RT_ERROR_NONE) {
+    MS_LOG(EXCEPTION) << "Call rt api rtModelLoadComplete failed, ret: " << std::hex << rt_ret;
+  }
+}
+
+void RuntimeModel::Load(uint32_t device_id, uint64_t session_id, const std::shared_ptr<DavinciModel> &davinci_model) {
+  InitResource(davinci_model);
+  GenerateTask(device_id, session_id, davinci_model);
+}
+
+void RuntimeModel::DistributeTask() {
+  MS_LOG(INFO) << "DistributeTask start.";
+  for (auto &task : task_list_) {
+    MS_EXCEPTION_IF_NULL(task);
+    task->Distribute();
+
+    uint32_t task_id = 0;
+    uint32_t stream_id = 0;
+    rtError_t rt_ret = rtModelGetTaskId(rt_model_handle_, &task_id, &stream_id);
+    if (rt_ret != RT_ERROR_NONE) {
+      MS_LOG(EXCEPTION) << "Call rt api rtModelGetTaskId failed, ret: " << std::hex << rt_ret;
+    }
+    task_id_list_.push_back(task_id);
+    stream_id_list_.push_back(stream_id);
+    if (task->Args() != nullptr) {
+      std::shared_ptr<RuntimeInfo> runtime_tuple = std::make_shared<RuntimeInfo>(task_id, stream_id, task->Args());
+      auto emplace_ret = runtime_info_map_.emplace(task->task_name(), runtime_tuple);
+      if (!emplace_ret.second) {
+        MS_LOG(WARNING) << "Task name exist: " << task->task_name();
+      }
+    }
+  }
+  if (task_list_.empty()) {
+    MS_LOG(EXCEPTION) << "Task list is empty";
+  }
+
+  MS_LOG(INFO) << "DistributeTask success.";
+}
+
+void RuntimeModel::Run() {
+  MS_LOG(INFO) << "Davinci task run start.";
+  rtError_t ret = rtModelExecute(rt_model_handle_, rt_model_stream_, 0);
+  if (ret != RT_ERROR_NONE) {
+    MS_LOG(EXCEPTION) << "Call rt api rtModelExecute failed, ret: " << std::hex << ret;
+  }
+
+  MS_LOG(INFO) << "Run rtModelExecute success, start to rtStreamSynchronize.";
+  ret = rtStreamSynchronize(rt_model_stream_);
+  if (ret != RT_ERROR_NONE) {
+    if (ret == ACL_ERROR_RT_END_OF_SEQUENCE) {
+      MS_LOG(INFO) << "Model stream ACL_ERROR_RT_END_OF_SEQUENCE signal received.";
+      return;
+    }
+    MS_LOG(EXCEPTION) << "Call rt api rtStreamSynchronize failed, ret: " << std::hex << ret;
+  }
+
+  MS_LOG(INFO) << "Davinci task run success.";
+}
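Run() treats ACL_ERROR_RT_END_OF_SEQUENCE as a normal exit (a dataset-style op signalling exhaustion), not as a failure; everything else throws. A caller-side sketch of how that distinction surfaces (editor's illustration; TryRunOneStep is a hypothetical wrapper):

    // Sketch: callers only see an exception for real failures; end-of-sequence
    // returns normally from RuntimeModel::Run().
    bool TryRunOneStep(mindspore::ge::model_runner::ModelRunner &runner, uint32_t model_id) {
      try {
        runner.RunModel(model_id);  // returns normally on ACL_ERROR_RT_END_OF_SEQUENCE
        return true;
      } catch (const std::exception &) {
        // e.g. AscendKernelRuntime::RunTask dumps task exception info here
        return false;
      }
    }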
+
+void RuntimeModel::RtModelUnbindStream() noexcept {
+  for (size_t i = 0; i < stream_list_.size(); i++) {
+    if (rtModelUnbindStream(rt_model_handle_, stream_list_[i]) != RT_ERROR_NONE) {
+      MS_LOG(ERROR) << "Unbind stream from model failed! Index: " << i;
+      return;
+    }
+  }
+}
+
+void RuntimeModel::RtStreamDestory() noexcept {
+  if (rtStreamDestroy(rt_model_stream_) != RT_ERROR_NONE) {
+    MS_LOG(ERROR) << "Destroy stream for rt_model failed!";
+    return;
+  }
+
+  for (size_t i = 0; i < stream_list_.size(); i++) {
+    if (rtStreamDestroy(stream_list_[i]) != RT_ERROR_NONE) {
+      MS_LOG(ERROR) << "Destroy stream failed! Index: " << i;
+      return;
+    }
+  }
+}
+
+void RuntimeModel::RtLabelDestory() noexcept {
+  for (size_t i = 0; i < label_list_.size(); i++) {
+    if (label_list_[i] == nullptr) {
+      continue;
+    }
+    if (rtLabelDestroy(label_list_[i]) != RT_ERROR_NONE) {
+      MS_LOG(ERROR) << "Destroy label failed! Index: " << i;
+      return;
+    }
+  }
+}
+
+void RuntimeModel::RtModelDestory() noexcept {
+  rtError_t ret = rtModelDestroy(rt_model_handle_);
+  if (ret != RT_ERROR_NONE) {
+    MS_LOG(ERROR) << "Call rt api rtModelDestroy failed, ret: " << std::hex << ret;
+    return;
+  }
+}
+
+void RuntimeModel::RtEventDestory() noexcept {
+  for (size_t i = 0; i < event_list_.size(); i++) {
+    if (rtEventDestroy(event_list_[i]) != RT_ERROR_NONE) {
+      MS_LOG(ERROR) << "Destroy event failed! Index: " << i;
+      return;
+    }
+  }
+}
+
+const std::vector<uint32_t> &RuntimeModel::GetTaskIdList() const { return task_id_list_; }
+
+const std::vector<uint32_t> &RuntimeModel::GetStreamIdList() const { return stream_id_list_; }
+}  // namespace mindspore::ge::model_runner
diff --git a/mindspore/ccsrc/runtime/device/ascend/ge_runtime/runtime_model.h b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/runtime_model.h
new file mode 100644
index 00000000000..3831f4994df
--- /dev/null
+++ b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/runtime_model.h
@@ -0,0 +1,71 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_RUNTIME_MODEL_H_
+#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_RUNTIME_MODEL_H_
+#include <map>
+#include <memory>
+#include <string>
+#include <tuple>
+#include <vector>
+#include "runtime/base.h"
+#include "runtime/rt_model.h"
+#include "runtime/device/ascend/ge_runtime/davinci_model.h"
+
+namespace mindspore::ge::model_runner {
+using RuntimeInfo = std::tuple<uint32_t, uint32_t, void *>;
+class Task;
+class RuntimeModel {
+ public:
+  RuntimeModel() = default;
+  ~RuntimeModel();
+
+  void Load(uint32_t device_id, uint64_t session_id, const std::shared_ptr<DavinciModel> &davinci_model);
+  void DistributeTask();
+  void LoadComplete();
+  const std::vector<uint32_t> &GetTaskIdList() const;
+  const std::vector<uint32_t> &GetStreamIdList() const;
+  const std::map<std::string, std::shared_ptr<RuntimeInfo>> &GetRuntimeInfoMap() const { return runtime_info_map_; }
+  rtModel_t GetModelHandle() const { return rt_model_handle_; }
+  void Run();
+
+ private:
+  void InitResource(const std::shared_ptr<DavinciModel> &davinci_model);
+  void GenerateTask(uint32_t device_id, uint64_t session_id, const std::shared_ptr<DavinciModel> &davinci_model);
+  void InitStream(const std::shared_ptr<DavinciModel> &davinci_model);
+  void InitEvent(uint32_t event_num);
+  void InitLabel(const std::shared_ptr<DavinciModel> &davinci_model);
+  void RtModelUnbindStream() noexcept;
+  void RtStreamDestory() noexcept;
+  void RtModelDestory() noexcept;
+  void RtLabelDestory() noexcept;
+  void RtEventDestory() noexcept;
+
+  rtModel_t rt_model_handle_{};
+  rtStream_t rt_model_stream_{};
+
+  std::vector<rtStream_t> stream_list_{};
+  std::vector<rtLabel_t> label_list_{};
+  std::vector<rtEvent_t> event_list_{};
+
+  std::vector<std::shared_ptr<Task>> task_list_{};
+
+  std::vector<uint32_t> task_id_list_{};
+  std::vector<uint32_t> stream_id_list_{};
+  std::map<std::string, std::shared_ptr<RuntimeInfo>> runtime_info_map_;
+};
+}  // namespace mindspore::ge::model_runner
+#endif  // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_RUNTIME_MODEL_H_
diff --git a/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/aicpu_task.cc b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/aicpu_task.cc
new file mode 100644
index 00000000000..3e0f23dcdc1
--- /dev/null
+++ b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/aicpu_task.cc
@@ -0,0 +1,168 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "runtime/device/ascend/ge_runtime/task/aicpu_task.h"
+#include <vector>
+#include "runtime/mem.h"
+#include "runtime/kernel.h"
+#include "runtime/device/ascend/ge_runtime/task/task_factory.h"
+#include "aicpu/common/aicpu_task_struct.h"
+
+namespace mindspore::ge::model_runner {
+AicpuTask::AicpuTask(const ModelContext &model_context, const std::shared_ptr<AicpuTaskInfo> &task_info)
+    : TaskRepeater<AicpuTaskInfo>(model_context, task_info),
+      task_info_(task_info),
+      stream_(nullptr),
+      args_(nullptr),
+      ext_info_(nullptr),
+      input_output_addr_(nullptr) {
+  MS_EXCEPTION_IF_NULL(task_info_);
+
+  auto stream_list = model_context.stream_list();
+  if (stream_list.size() == 1) {
+    stream_ = stream_list[0];
+  } else if (stream_list.size() > task_info_->stream_id()) {
+    stream_ = stream_list[task_info_->stream_id()];
+  } else {
+    MS_LOG(EXCEPTION) << "Index: " << task_info_->stream_id() << " >= stream_list.size(): " << stream_list.size();
+  }
+}
+
+AicpuTask::~AicpuTask() {
+  ReleaseRtMem(&args_);
+  ReleaseRtMem(&ext_info_);
+}
RT_KERNEL_DUMPFLAG : RT_KERNEL_DEFAULT;
+
+  MS_LOG(INFO) << "Distribute AicpuTask start, args_size = " << args_size << ", io_addrs_num = " << io_addrs_num
+               << ", so_name = " << task_info_->so_name() << ", kernel_name = " << task_info_->kernel_name()
+               << ", dump_flag = " << dump_flag;
+  rt_ret = rtCpuKernelLaunchWithFlag(reinterpret_cast<const void *>(task_info_->so_name().data()),
+                                     reinterpret_cast<const void *>(task_info_->kernel_name().data()), 1, args_,
+                                     args_size, nullptr, stream_, dump_flag);
+  if (rt_ret != RT_ERROR_NONE) {
+    MS_LOG(EXCEPTION) << "Call rt api rtCpuKernelLaunchWithFlag failed, ret: " << std::hex << rt_ret;
+  }
+
+  MS_LOG(INFO) << "Distribute AicpuTask end.";
+}
+
+void AicpuTask::ReleaseRtMem(void **ptr) noexcept {
+  if (ptr == nullptr || *ptr == nullptr) {
+    return;
+  }
+
+  rtError_t rt_ret = rtFree(*ptr);
+  if (rt_ret != RT_ERROR_NONE) {
+    return;
+  }
+  *ptr = nullptr;
+}
+
+void AicpuTask::SetAicpuParamHead(uint32_t args_size, uint32_t io_addrs_num) {
+  aicpu::AicpuParamHead aicpu_param_head;
+  aicpu_param_head.length = args_size;
+  aicpu_param_head.ioAddrNum = io_addrs_num;
+
+  const auto &ext_info = task_info_->ext_info();
+  uint32_t ext_size = ext_info.size();
+  if (ext_info.empty()) {
+    aicpu_param_head.extInfoLength = 0;
+    aicpu_param_head.extInfoAddr = 0;
+  } else {
+    rtError_t flag = rtMalloc(&ext_info_, ext_size, RT_MEMORY_HBM);
+    if (flag != RT_ERROR_NONE) {
+      MS_LOG(EXCEPTION) << "Call rt api rtMalloc failed, ret: " << std::hex << flag;
+    }
+
+    flag = rtMemcpy(ext_info_, ext_size, const_cast<void *>(reinterpret_cast<const void *>(ext_info.data())), ext_size,
+                    RT_MEMCPY_HOST_TO_DEVICE);
+    if (flag != RT_ERROR_NONE) {
+      MS_LOG(EXCEPTION) << "Call rt api rtMemcpy failed, ret: " << std::hex << flag;
+    }
+
+    MS_LOG(INFO) << "ext info size: " << ext_size;
+    aicpu_param_head.extInfoLength = ext_size;
+    aicpu_param_head.extInfoAddr = reinterpret_cast<uintptr_t>(ext_info_);
+  }
+
+  // Memcpy AicpuParamHead
+  auto rt_ret = rtMemcpy(args_, sizeof(aicpu::AicpuParamHead), reinterpret_cast<void *>(&aicpu_param_head),
+                         sizeof(aicpu::AicpuParamHead), RT_MEMCPY_HOST_TO_DEVICE);
+  if (rt_ret != RT_ERROR_NONE) {
+    MS_LOG(EXCEPTION) << "Call rt api rtMemcpy failed, ret: " << std::hex << rt_ret;
+  }
+}
+
+void AicpuTask::SetInputOutputAddrs(const std::vector<void *> &io_addrs, uint32_t io_addr_offset) {
+  // Memcpy io addrs
+  if (!io_addrs.empty()) {
+    auto rt_ret = rtMemcpy(reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(args_) + io_addr_offset),
+                           static_cast<uint32_t>(io_addrs.size()) * sizeof(void *), io_addrs.data(),
+                           static_cast<uint32_t>(io_addrs.size()) * sizeof(void *), RT_MEMCPY_HOST_TO_DEVICE);
+    if (rt_ret != RT_ERROR_NONE) {
+      MS_LOG(EXCEPTION) << "Call rt api rtMemcpy failed, ret: " << std::hex << rt_ret;
+    }
+  }
+}
+
+void AicpuTask::SetNodeDef(uint32_t node_def_len_offset, uint32_t node_def_addr_offset) {
+  // Memcpy node def length
+  auto size = task_info_->node_def().size();
+  auto rt_ret =
+    rtMemcpy(reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(args_) + node_def_len_offset), sizeof(uint32_t),
+             reinterpret_cast<const void *>(&size), sizeof(uint32_t), RT_MEMCPY_HOST_TO_DEVICE);
+  if (rt_ret != RT_ERROR_NONE) {
+    MS_LOG(EXCEPTION) << "Call rt api rtMemcpy failed, ret: " << std::hex << rt_ret;
+  }
+
+  // Memcpy node def body
+  rt_ret = rtMemcpy(reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(args_) + node_def_addr_offset),
+                    task_info_->node_def().size(), reinterpret_cast<const void *>(task_info_->node_def().data()),
+                    task_info_->node_def().size(), RT_MEMCPY_HOST_TO_DEVICE);
+  if (rt_ret != RT_ERROR_NONE) {
+    MS_LOG(EXCEPTION) << "Call rt api rtMemcpy failed, ret: " << std::hex << rt_ret;
+  }
+}
+
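The args buffer assembled above is one contiguous device allocation: aicpu::AicpuParamHead, then the flattened input/output pointers, then a uint32_t node_def length, then the serialized node_def bytes. A minimal host-side sketch of the same offset arithmetic; ParamHead and BuildArgs are illustrative names, not part of the patch, and the real head struct lives in aicpu/common/aicpu_task_struct.h:

#include <cstdint>
#include <cstring>
#include <string>
#include <vector>

// Hypothetical stand-in for aicpu::AicpuParamHead (field names mirror the patch).
struct ParamHead {
  uint32_t length;
  uint32_t ioAddrNum;
  uint32_t extInfoLength;
  uint64_t extInfoAddr;
};

// Computes the same offsets as AicpuTask::Distribute: head | io addrs | len | node_def.
std::vector<uint8_t> BuildArgs(const std::vector<void *> &io_addrs, const std::string &node_def) {
  const uint32_t io_addr_offset = sizeof(ParamHead);
  const uint32_t io_addrs_size = static_cast<uint32_t>(io_addrs.size() * sizeof(void *));
  const uint32_t node_def_len_offset = io_addr_offset + io_addrs_size;
  const uint32_t node_def_addr_offset = node_def_len_offset + sizeof(uint32_t);
  const uint32_t args_size = node_def_addr_offset + static_cast<uint32_t>(node_def.size());

  std::vector<uint8_t> args(args_size, 0);
  ParamHead head{args_size, static_cast<uint32_t>(io_addrs.size()), 0, 0};
  std::memcpy(args.data(), &head, sizeof(head));
  std::memcpy(args.data() + io_addr_offset, io_addrs.data(), io_addrs_size);
  const uint32_t len = static_cast<uint32_t>(node_def.size());
  std::memcpy(args.data() + node_def_len_offset, &len, sizeof(len));
  std::memcpy(args.data() + node_def_addr_offset, node_def.data(), node_def.size());
  return args;
}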
+REGISTER_TASK(TaskInfoType::AICPU, AicpuTask, AicpuTaskInfo); +} // namespace mindspore::ge::model_runner diff --git a/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/aicpu_task.h b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/aicpu_task.h new file mode 100644 index 00000000000..4c7cded24b9 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/aicpu_task.h @@ -0,0 +1,51 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_AICPU_TASK_H_ +#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_AICPU_TASK_H_ + +#include +#include +#include +#include "runtime/device/ascend/ge_runtime/task/task.h" + +namespace mindspore::ge::model_runner { +class AicpuTask : public TaskRepeater { + public: + AicpuTask(const ModelContext &model_context, const std::shared_ptr &task_info); + + ~AicpuTask() override; + + void Distribute() override; + + void *Args() override { return input_output_addr_; } + + std::string task_name() const override { return task_info_->op_name(); } + + private: + static void ReleaseRtMem(void **ptr) noexcept; + void SetAicpuParamHead(uint32_t args_size, uint32_t io_addrs_num); + void SetInputOutputAddrs(const std::vector &io_addrs, uint32_t io_addr_offset); + void SetNodeDef(uint32_t node_def_len_offset, uint32_t node_def_addr_offset); + + std::shared_ptr task_info_; + void *stream_; + void *args_; + void *ext_info_; + void *input_output_addr_; +}; +} // namespace mindspore::ge::model_runner +#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_AICPU_TASK_H_ diff --git a/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/event_record_task.cc b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/event_record_task.cc new file mode 100644 index 00000000000..cc8313088be --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/event_record_task.cc @@ -0,0 +1,54 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
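Every task source file ends with a REGISTER_TASK(...) line; the macro is defined in task_factory.h, which this patch adds but which is not shown in this hunk. A plausible sketch of the mechanism such a macro usually expands to, assuming a type-keyed creator map filled by file-scope registrars before main() runs; the real macro may differ in detail:

#include <functional>
#include <map>
#include <memory>

// Illustrative only: a factory keyed by task type. TaskInfoType, Task,
// ModelContext and TaskInfo are the patch's types; this Factory is a guess
// at the shape of the mechanism, not the actual code.
template <typename Key, typename Base, typename... Args>
class Factory {
 public:
  using Creator = std::function<std::shared_ptr<Base>(Args...)>;
  static Factory &Instance() {
    static Factory inst;
    return inst;
  }
  void Register(Key key, Creator creator) { creators_[key] = std::move(creator); }
  std::shared_ptr<Base> Create(Key key, Args... args) {
    auto it = creators_.find(key);
    return it == creators_.end() ? nullptr : it->second(args...);
  }

 private:
  std::map<Key, Creator> creators_;
};

// A REGISTER_XXX macro then boils down to a file-scope object whose
// constructor calls Register() during static initialization.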
+ */ + +#include "runtime/device/ascend/ge_runtime/task/event_record_task.h" +#include "runtime/kernel.h" +#include "runtime/device/ascend/ge_runtime/task/task_factory.h" + +namespace mindspore::ge::model_runner { +EventRecordTask::EventRecordTask(const ModelContext &model_context, + const std::shared_ptr &task_info) + : TaskRepeater(model_context, task_info), + task_info_(task_info), + stream_(nullptr), + event_(nullptr) { + MS_EXCEPTION_IF_NULL(task_info_); + auto stream_list = model_context.stream_list(); + auto event_list = model_context.event_list(); + uint32_t stream_id = task_info_->stream_id(); + uint32_t event_id = task_info_->event_id(); + if (stream_id >= stream_list.size() || event_id >= event_list.size()) { + MS_LOG(EXCEPTION) << "stream_list size: " << stream_list.size() << ", stream_id: " << stream_id + << ", event_list size: " << event_list.size() << ", event_id: " << event_id; + } + stream_ = stream_list[stream_id]; + event_ = event_list[event_id]; +} + +EventRecordTask::~EventRecordTask() {} + +void EventRecordTask::Distribute() { + MS_LOG(INFO) << "EventRecordTask Distribute start, stream: " << stream_ << ", event: " << event_ + << ", stream_id: " << task_info_->stream_id() << ", event_id: " << task_info_->event_id(); + rtError_t rt_ret = rtEventRecord(event_, stream_); + if (rt_ret != RT_ERROR_NONE) { + MS_LOG(EXCEPTION) << "Call rt api rtEventRecord failed, ret: " << std::hex << rt_ret; + } + MS_LOG(INFO) << "Distribute end."; +} + +REGISTER_TASK(TaskInfoType::EVENT_RECORD, EventRecordTask, EventRecordTaskInfo); +} // namespace mindspore::ge::model_runner diff --git a/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/event_record_task.h b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/event_record_task.h new file mode 100644 index 00000000000..71084eb8d3b --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/event_record_task.h @@ -0,0 +1,38 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
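EventRecordTask, together with the EventWaitTask added below, forms a cross-stream handshake: one stream records the event, another blocks until it fires and then resets it so a sunk model can replay the same task list. Reduced to the plain runtime calls the two tasks distribute (a sketch; error handling elided, handles assumed valid):

// Producer stream signals, consumer stream waits, then the event is re-armed.
void Handshake(rtStream_t producer, rtStream_t consumer, rtEvent_t event) {
  (void)rtEventRecord(event, producer);      // EventRecordTask::Distribute
  (void)rtStreamWaitEvent(consumer, event);  // EventWaitTask::Distribute
  (void)rtEventReset(event, consumer);       // re-arm for the next model run
}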
+ */ + +#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_EVENT_RECORD_TASK_H_ +#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_EVENT_RECORD_TASK_H_ + +#include +#include "runtime/device/ascend/ge_runtime/task/task.h" + +namespace mindspore::ge::model_runner { +class EventRecordTask : public TaskRepeater { + public: + EventRecordTask(const ModelContext &model_context, const std::shared_ptr &task_info); + + ~EventRecordTask() override; + + void Distribute() override; + + private: + std::shared_ptr task_info_; + rtStream_t stream_; + rtEvent_t event_; +}; +} // namespace mindspore::ge::model_runner +#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_EVENT_RECORD_TASK_H_ diff --git a/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/event_wait_task.cc b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/event_wait_task.cc new file mode 100644 index 00000000000..d26617f5770 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/event_wait_task.cc @@ -0,0 +1,59 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "runtime/device/ascend/ge_runtime/task/event_wait_task.h" +#include "runtime/kernel.h" +#include "runtime/device/ascend/ge_runtime/task/task_factory.h" + +namespace mindspore::ge::model_runner { +EventWaitTask::EventWaitTask(const ModelContext &model_context, const std::shared_ptr &task_info) + : TaskRepeater(model_context, task_info), + task_info_(task_info), + stream_(nullptr), + event_(nullptr) { + MS_EXCEPTION_IF_NULL(task_info_); + auto stream_list = model_context.stream_list(); + auto event_list = model_context.event_list(); + uint32_t stream_id = task_info_->stream_id(); + uint32_t event_id = task_info_->event_id(); + if (stream_id >= stream_list.size() || event_id >= event_list.size()) { + MS_LOG(EXCEPTION) << "stream_list size: " << stream_list.size() << ", stream_id: " << stream_id + << ", event_list size: " << event_list.size() << ", event_id: " << event_id; + } + stream_ = stream_list[stream_id]; + event_ = event_list[event_id]; +} + +EventWaitTask::~EventWaitTask() {} + +void EventWaitTask::Distribute() { + MS_LOG(INFO) << "EventWaitTask Distribute start, stream: " << stream_ << ", event: " << event_ + << ", stream_id: " << task_info_->stream_id() << ", event_id: " << task_info_->event_id(); + + rtError_t rt_ret = rtStreamWaitEvent(stream_, event_); + if (rt_ret != RT_ERROR_NONE) { + MS_LOG(EXCEPTION) << "Call rt api rtStreamWaitEvent failed, ret: " << std::hex << rt_ret; + } + + rt_ret = rtEventReset(event_, stream_); + if (rt_ret != RT_ERROR_NONE) { + MS_LOG(EXCEPTION) << "Call rt api rtEventReset failed, ret: " << std::hex << rt_ret; + } + MS_LOG(INFO) << "Distribute end."; +} + +REGISTER_TASK(TaskInfoType::EVENT_WAIT, EventWaitTask, EventWaitTaskInfo); +} // namespace mindspore::ge::model_runner diff --git a/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/event_wait_task.h 
b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/event_wait_task.h new file mode 100644 index 00000000000..2809a10578e --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/event_wait_task.h @@ -0,0 +1,38 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_EVENT_WAIT_TASK_H_ +#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_EVENT_WAIT_TASK_H_ + +#include +#include "runtime/device/ascend/ge_runtime/task/task.h" + +namespace mindspore::ge::model_runner { +class EventWaitTask : public TaskRepeater { + public: + EventWaitTask(const ModelContext &model_context, const std::shared_ptr &task_info); + + ~EventWaitTask() override; + + void Distribute() override; + + private: + std::shared_ptr task_info_; + rtStream_t stream_; + rtEvent_t event_; +}; +} // namespace mindspore::ge::model_runner +#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_EVENT_WAIT_TASK_H_ diff --git a/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/hccl_task.cc b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/hccl_task.cc new file mode 100644 index 00000000000..39bfaf080f3 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/hccl_task.cc @@ -0,0 +1,221 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "runtime/device/ascend/ge_runtime/task/hccl_task.h" +#include +#include "runtime/device/ascend/ge_runtime/task/task_factory.h" +#include "common/opskernel/ops_kernel_info_store.h" +#include "common/opskernel/ge_task_info.h" + +namespace mindspore::ge::model_runner { +std::map>>> + HcclTask::model_stream_mapping_; +std::mutex HcclTask::model_stream_mapping_mutex_; + +HcclTask::HcclTask(const ModelContext &model_context, const std::shared_ptr &task_info) + : TaskRepeater(model_context, task_info), + task_info_(task_info), + stream_(nullptr), + workspace_mem_(nullptr), + rt_model_handle_(nullptr), + priority_(0), + secondary_stream_list_() { + MS_EXCEPTION_IF_NULL(task_info_); + + priority_ = model_context.priority(); + rt_model_handle_ = model_context.rt_model_handle(); + auto stream_list = model_context.stream_list(); + + if (stream_list.size() == 1) { + stream_ = stream_list[0]; + } else if (stream_list.size() > task_info_->stream_id()) { + stream_ = stream_list[task_info_->stream_id()]; + } else { + MS_LOG(EXCEPTION) << "Index: " << task_info_->stream_id() << " >= stream_list.size(): " << stream_list.size(); + } +} + +HcclTask::~HcclTask() {} + +void HcclTask::Distribute() { + // Ops kernel info store + // Get privateDef and opsKernelStorePtr + MS_LOG(INFO) << "Distribute hccl task start."; + void *ops_kernel_store = task_info_->ops_kernel_store(); + ::ge::OpsKernelInfoStore *ops_kernel_info_store = reinterpret_cast<::ge::OpsKernelInfoStore *>(ops_kernel_store); + MS_EXCEPTION_IF_NULL(ops_kernel_info_store); + + char *private_def = reinterpret_cast(const_cast(task_info_->private_def().data())); + auto private_def_len = static_cast(task_info_->private_def().size()); + MS_LOG(INFO) << "The first address of the custom info, privateDef= " << private_def; + SetSecondaryStream(); + + if (task_info_->workspace_size() > 0) { + workspace_mem_ = task_info_->workspace_addr(); + } + + ::ge::GETaskInfo ge_task; + ge_task.id = 0; + ge_task.type = static_cast(RT_MODEL_TASK_HCCL); + ge_task.stream = stream_; + + ge_task.kernelHcclInfo = std::vector<::ge::GETaskKernelHcclInfo>(1); + ge_task.kernelHcclInfo[0].hccl_type = task_info_->hccl_type(); + ge_task.kernelHcclInfo[0].inputDataAddr = task_info_->input_data_addr(); + ge_task.kernelHcclInfo[0].outputDataAddr = task_info_->output_data_addr(); + ge_task.kernelHcclInfo[0].workSpaceAddr = workspace_mem_; + ge_task.kernelHcclInfo[0].workSpaceMemSize = task_info_->workspace_size(); + ge_task.kernelHcclInfo[0].count = task_info_->count(); + ge_task.kernelHcclInfo[0].dataType = static_cast(task_info_->data_type()); + ge_task.kernelHcclInfo[0].opType = static_cast(task_info_->op_type()); + ge_task.kernelHcclInfo[0].rootId = task_info_->root_id(); + + std::vector secondary_stream_list; + std::transform(secondary_stream_list_.begin(), secondary_stream_list_.end(), + std::back_inserter(secondary_stream_list), + [](const std::shared_ptr &stream) -> rtStream_t { return stream->GetStream(); }); + ge_task.kernelHcclInfo[0].hcclStreamList = secondary_stream_list; + + ge_task.privateDef = private_def; + ge_task.privateDefLen = private_def_len; + ge_task.opsKernelStorePtr = ops_kernel_store; + + MS_LOG(INFO) << "Begin to call function LoadTask in hccl."; + auto result = ops_kernel_info_store->LoadTask(ge_task); + // tagHcclResult::HCCL_SUCCESS is 0 + if (result != 0) { + MS_LOG(EXCEPTION) << "davinci_model : load task fail, return ret: " << result; + } + + MS_LOG(INFO) << "Call function LoadTask end."; +} + +void HcclTask::SetSecondaryStream() { + const 
uint32_t master_stream_id = task_info_->stream_id();
+  const int64_t hccl_secondary_stream_num = task_info_->hccl_stream_num();
+  std::lock_guard<std::mutex> lock(model_stream_mapping_mutex_);
+
+  // no model, create all secondary streams
+  auto model_iter = model_stream_mapping_.find(rt_model_handle_);
+  if (model_iter == model_stream_mapping_.end()) {
+    MS_LOG(INFO) << "Need to create map for rt_model_handle_: " << rt_model_handle_ << " with new mainstream "
+                 << master_stream_id;
+    CreateStream(hccl_secondary_stream_num, master_stream_id);
+    MS_LOG(INFO) << "Initialize hccl secondary stream success, hccl_secondary_stream_num=" << hccl_secondary_stream_num;
+    return;
+  }
+
+  // has model, but no secondary stream before, create all secondary streams
+  auto &master_secondary_stream_map = model_iter->second;
+  auto iter = master_secondary_stream_map.find(master_stream_id);
+  if (iter == master_secondary_stream_map.end()) {
+    MS_LOG(INFO) << "Need to create secondary stream for " << task_info_->op_name() << " with new mainstream "
+                 << master_stream_id;
+    CreateStream(hccl_secondary_stream_num, master_stream_id);
+    MS_LOG(INFO) << "Initialize hccl secondary stream success, hccl_secondary_stream_num=" << hccl_secondary_stream_num;
+    return;
+  }
+
+  // has model and secondary streams, but not enough of them to be reused
+  std::vector<std::weak_ptr<StreamGuard>> &secondary_stream_vec = iter->second;
+  if (static_cast<size_t>(hccl_secondary_stream_num) > secondary_stream_vec.size()) {
+    size_t created_stream_num = secondary_stream_vec.size();
+    auto need_to_create_num = hccl_secondary_stream_num - created_stream_num;
+    MS_LOG(INFO) << "Need to reuse " << secondary_stream_vec.size() << " secondary stream and create "
+                 << need_to_create_num << " new secondary stream.";
+    for (size_t i = 0; i < secondary_stream_vec.size(); ++i) {
+      secondary_stream_list_.push_back(GetSecondaryStream(&secondary_stream_vec, i));
+    }
+    CreateStream(need_to_create_num, master_stream_id);
+    MS_LOG(INFO) << "Initialize hccl secondary stream success, hccl_secondary_stream_num=" << hccl_secondary_stream_num;
+    return;
+  }
+
+  // all can be reused
+  MS_LOG(INFO) << "Number of secondary stream " << hccl_secondary_stream_num << " is enough to be reused.";
+  for (int64_t i = 0; i < hccl_secondary_stream_num; ++i) {
+    secondary_stream_list_.push_back(GetSecondaryStream(&secondary_stream_vec, i));
+  }
+  MS_LOG(INFO) << "Initialize hccl secondary stream success, hccl_secondary_stream_num = " << hccl_secondary_stream_num;
+}
+
+void HcclTask::CreateStream(int64_t stream_num, int64_t master_stream_id) {
+  MS_LOG(INFO) << "Start to create " << stream_num << " hccl secondary stream.";
+  for (int64_t i = 0; i < stream_num; ++i) {
+    rtStream_t stream = nullptr;
+    CreateStream(rt_model_handle_, &stream);
+    auto shared_stream = std::make_shared<StreamGuard>(rt_model_handle_, stream);
+    SaveHcclSecondaryStream(master_stream_id, shared_stream);
+    secondary_stream_list_.push_back(shared_stream);
+  }
+  MS_LOG(INFO) << "CreateStream success.";
+}
+
+void HcclTask::CreateStream(rtModel_t model, rtStream_t *stream) const {
+  MS_EXCEPTION_IF_NULL(stream);
+
+  rtError_t rt_ret = rtStreamCreateWithFlags(stream, priority_, RT_STREAM_PERSISTENT | RT_STREAM_FORCE_COPY);
+  if (rt_ret != RT_ERROR_NONE) {
+    MS_LOG(EXCEPTION) << "Call rt api rtStreamCreateWithFlags failed, ret: " << std::hex << rt_ret;
+  }
+  // Create secondary stream, inactive by default, activated by hccl
+  rt_ret = rtModelBindStream(model, *stream, RT_MODEL_WAIT_ACTIVE_STREAM);
+  if (rt_ret != RT_ERROR_NONE) {
+    MS_LOG(EXCEPTION) << "Call rt api rtModelBindStream failed, ret: " << std::hex << rt_ret;
+  }
+}
+
+void HcclTask::SaveHcclSecondaryStream(int64_t master_stream_id, const std::shared_ptr<StreamGuard> &stream) {
+  if (model_stream_mapping_.find(rt_model_handle_) == model_stream_mapping_.end()) {
+    model_stream_mapping_.emplace(rt_model_handle_, std::map<int64_t, std::vector<std::weak_ptr<StreamGuard>>>());
+  }
+  std::map<int64_t, std::vector<std::weak_ptr<StreamGuard>>> &master_secondary_stream_map =
+    model_stream_mapping_.at(rt_model_handle_);
+  master_secondary_stream_map[master_stream_id].emplace_back(stream);
+}
+
+HcclTask::StreamGuard::~StreamGuard() {
+  rtError_t rt_ret = rtModelUnbindStream(model_, stream_);
+  if (rt_ret != RT_ERROR_NONE) {
+    MS_LOG(ERROR) << "Call rt api rtModelUnbindStream failed, ret: " << std::hex << rt_ret;
+    return;
+  }
+
+  rt_ret = rtStreamDestroy(stream_);
+  if (rt_ret != RT_ERROR_NONE) {
+    MS_LOG(ERROR) << "Call rt api rtStreamDestroy failed, ret: " << std::hex << rt_ret;
+    return;
+  }
+}
+
+std::shared_ptr<HcclTask::StreamGuard> HcclTask::GetSecondaryStream(
+  std::vector<std::weak_ptr<StreamGuard>> *secondary_streams, size_t index) {
+  MS_EXCEPTION_IF_NULL(secondary_streams);
+  if (index >= secondary_streams->size()) {
+    MS_LOG(EXCEPTION) << "Invalid stream index " << index << ", secondary streams size " << secondary_streams->size();
+  }
+  auto stream = secondary_streams->at(index).lock();
+  if (stream == nullptr) {
+    rtStream_t new_stream = nullptr;
+    CreateStream(rt_model_handle_, &new_stream);
+    stream = std::make_shared<StreamGuard>(rt_model_handle_, new_stream);
+    (*secondary_streams)[index] = stream;
+  }
+  return stream;
+}
+
+REGISTER_TASK(TaskInfoType::HCCL, HcclTask, HcclTaskInfo);
+}  // namespace mindspore::ge::model_runner
diff --git a/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/hccl_task.h b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/hccl_task.h
new file mode 100644
index 00000000000..406750320d3
--- /dev/null
+++ b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/hccl_task.h
@@ -0,0 +1,68 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
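model_stream_mapping_ holds the secondary streams as weak_ptr, so a stream is reused for as long as some HcclTask still owns its StreamGuard and is transparently recreated by GetSecondaryStream once every guard has expired. The caching idiom in isolation, as a minimal sketch (WeakCache is an illustrative name, not patch code):

#include <map>
#include <memory>
#include <mutex>

// Illustrative weak_ptr cache mirroring model_stream_mapping_: entries do not
// keep the resource alive; lock() either revives a live entry or signals that
// the slot must be refilled.
template <typename K, typename V>
class WeakCache {
 public:
  std::shared_ptr<V> GetOrCreate(const K &key, std::shared_ptr<V> (*make)()) {
    std::lock_guard<std::mutex> lock(mutex_);
    if (auto alive = cache_[key].lock()) {
      return alive;  // reuse while some owner still holds a shared_ptr
    }
    auto fresh = make();
    cache_[key] = fresh;
    return fresh;
  }

 private:
  std::mutex mutex_;
  std::map<K, std::weak_ptr<V>> cache_;
};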
+ */ + +#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_HCCL_TASK_H_ +#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_HCCL_TASK_H_ + +#include +#include +#include +#include +#include +#include "runtime/device/ascend/ge_runtime/task/task.h" + +namespace mindspore::ge::model_runner { +class HcclTask : public TaskRepeater { + public: + HcclTask(const ModelContext &model_context, const std::shared_ptr &task_info); + + ~HcclTask() override; + + void Distribute() override; + + private: + class StreamGuard; + void SetSecondaryStream(); + void CreateStream(int64_t stream_num, int64_t master_stream_id); + void CreateStream(rtModel_t model, rtStream_t *stream) const; + void SaveHcclSecondaryStream(int64_t master_stream_id, const std::shared_ptr &stream); + std::shared_ptr GetSecondaryStream(std::vector> *secondary_streams, + size_t index); + + std::shared_ptr task_info_; + void *stream_; + void *workspace_mem_; + rtModel_t rt_model_handle_; + int32_t priority_; + std::vector> secondary_stream_list_; + + // map>> + static std::map>>> model_stream_mapping_; + static std::mutex model_stream_mapping_mutex_; +}; + +class HcclTask::StreamGuard { + public: + StreamGuard(rtModel_t model, rtStream_t stream) : model_(model), stream_(stream) {} + ~StreamGuard(); + rtStream_t GetStream() const { return stream_; } + + private: + rtModel_t model_; + rtStream_t stream_; +}; +} // namespace mindspore::ge::model_runner +#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_HCCL_TASK_H_ diff --git a/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/label_goto_task.cc b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/label_goto_task.cc new file mode 100644 index 00000000000..1e6cc36c093 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/label_goto_task.cc @@ -0,0 +1,83 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "runtime/device/ascend/ge_runtime/task/label_goto_task.h" +#include "runtime/mem.h" +#include "runtime/device/ascend/ge_runtime/task/task_factory.h" + +namespace mindspore::ge::model_runner { +LabelGotoTask::LabelGotoTask(const ModelContext &model_context, const std::shared_ptr &task_info) + : TaskRepeater(model_context, task_info), + task_info_(task_info), + stream_(nullptr), + index_value_(nullptr) { + MS_EXCEPTION_IF_NULL(task_info_); + auto stream_list = model_context.stream_list(); + auto label_list = model_context.label_list(); + rt_model_handle_ = model_context.rt_model_handle(); + uint32_t stream_id = task_info_->stream_id(); + label_id_ = task_info_->label_id(); + MS_LOG(INFO) << "Stream list size: " << stream_list.size() << ", stream id: " << stream_id; + MS_LOG(INFO) << "Label list size: " << label_list.size() << ", label id: " << label_id_; + if (stream_id >= stream_list.size() || label_id_ >= label_list.size()) { + MS_LOG(EXCEPTION) << "Stream/Label id invalid."; + } + stream_ = stream_list[stream_id]; + label_manager_ = LabelManager::GetInstance(); + MS_EXCEPTION_IF_NULL(label_manager_); + label_info_ = label_manager_->GetLabelInfo(rt_model_handle_, {label_id_}, label_list); + MS_EXCEPTION_IF_NULL(label_info_); +} + +LabelGotoTask::~LabelGotoTask() { + if (index_value_ != nullptr) { + rtError_t rt_ret = rtFree(index_value_); + if (rt_ret != RT_ERROR_NONE) { + MS_LOG(ERROR) << "Call rtFree index_value_ failed, ret: " << std::hex << rt_ret; + } + index_value_ = nullptr; + } +} + +void LabelGotoTask::Distribute() { + MS_LOG(INFO) << "LabelGotoTask Distribute start."; + MS_EXCEPTION_IF_NULL(stream_); + MS_EXCEPTION_IF_NULL(label_info_); + + if (index_value_ == nullptr) { + rtError_t rt_ret = rtMalloc(&index_value_, sizeof(uint64_t), RT_MEMORY_HBM); + if (rt_ret != RT_ERROR_NONE) { + MS_LOG(EXCEPTION) << "Call rt api rtMalloc failed, ret: " << std::hex << rt_ret; + } + + uint64_t index = 0; + rt_ret = rtMemcpy(index_value_, sizeof(uint64_t), &index, sizeof(index), RT_MEMCPY_HOST_TO_DEVICE); + if (rt_ret != RT_ERROR_NONE) { + MS_LOG(EXCEPTION) << "Call rt api rtMemcpy failed, ret: " << std::hex << rt_ret; + } + } + + void *label_info = label_info_->GetLabelInfo(); + rtError_t rt_ret = rtLabelSwitchByIndex(index_value_, 1, label_info, stream_); + if (rt_ret != RT_ERROR_NONE) { + MS_LOG(EXCEPTION) << "Call rt api rtLabelSwitchByIndex failed, ret: " << std::hex << rt_ret; + } + + MS_LOG(INFO) << "DistributeTask end."; +} + +REGISTER_TASK(TaskInfoType::LABEL_GOTO, LabelGotoTask, LabelGotoTaskInfo); +} // namespace mindspore::ge::model_runner diff --git a/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/label_goto_task.h b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/label_goto_task.h new file mode 100644 index 00000000000..6048617c4dd --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/label_goto_task.h @@ -0,0 +1,46 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
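LabelGotoTask emulates an unconditional jump: there is no dedicated goto call here, so it drives rtLabelSwitchByIndex with a one-entry label table and a device-side index pinned to 0, which always takes the single branch. The call sequence reduces to this sketch (error checks elided, buffers assumed already prepared as in Distribute above):

// index_value: device uint64_t holding 0; label_info: device table with one entry.
void GotoLabel(void *index_value, void *label_info, rtStream_t stream) {
  // branch count is 1, so index 0 always selects the target label
  (void)rtLabelSwitchByIndex(index_value, 1, label_info, stream);
}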
+ */ + +#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_ +#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_ + +#include +#include +#include +#include +#include "runtime/device/ascend/ge_runtime/task/task.h" +#include "runtime/device/ascend/ge_runtime/task/label_manager.h" + +namespace mindspore::ge::model_runner { +class LabelGotoTask : public TaskRepeater { + public: + LabelGotoTask(const ModelContext &model_context, const std::shared_ptr &task_info); + + ~LabelGotoTask() override; + + void Distribute() override; + + private: + std::shared_ptr task_info_; + void *stream_; + std::shared_ptr label_info_; + void *index_value_; + uint32_t label_id_; + rtModel_t rt_model_handle_; + std::shared_ptr label_manager_; +}; +} // namespace mindspore::ge::model_runner +#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_ diff --git a/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/label_manager.cc b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/label_manager.cc new file mode 100644 index 00000000000..9d7042424bf --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/label_manager.cc @@ -0,0 +1,116 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "runtime/device/ascend/ge_runtime/task/label_manager.h" +#include +#include +#include "runtime/mem.h" +#include "runtime/rt_model.h" +#include "mindspore/core/utils/log_adapter.h" + +namespace mindspore::ge::model_runner { +std::weak_ptr LabelManager::instance_; +std::mutex LabelManager::instance_mutex_; + +template +static std::string GetVectorString(const std::vector &vec) { + std::string ret; + for (size_t i = 0; i < vec.size(); ++i) { + if (i != 0) { + ret.push_back(','); + } + ret += std::to_string(vec[i]); + } + return ret; +} + +LabelGuard::~LabelGuard() { + void *label_info = GetLabelInfo(); + if (label_info != nullptr) { + rtError_t rt_ret = rtFree(label_info); + if (rt_ret != RT_ERROR_NONE) { + MS_LOG(ERROR) << "rtFree label_info failed! 
ret: " << std::hex << rt_ret; + } + } +} + +std::shared_ptr LabelManager::GetInstance() { + std::lock_guard lock(instance_mutex_); + auto instance = instance_.lock(); + if (instance != nullptr) { + return instance; + } + + instance = std::make_shared(); + instance_ = instance; + return instance; +} + +std::shared_ptr LabelManager::GetLabelInfo(rtModel_t model, const std::vector &label_ids, + const std::vector &all_label) { + std::lock_guard lock(model_info_mapping_mutex_); + rtError_t rt_ret; + auto model_iter = model_info_mapping_.find(model); + if (model_iter == model_info_mapping_.end()) { + model_info_mapping_.emplace(model, std::map>()); + model_iter = model_info_mapping_.find(model); + } + + std::string label_id_str = GetVectorString(label_ids); + auto &label_map = model_iter->second; + auto label_iter = label_map.find(label_id_str); + if (label_iter != label_map.end()) { + auto label_guard = label_iter->second.lock(); + if (label_guard != nullptr) { + MS_LOG(INFO) << "model " << model << " find same label id " << label_id_str; + return label_guard; + } + } + + MS_LOG(INFO) << "Alloc label id " << label_id_str << " for model " << model; + void *label_info = nullptr; + std::vector label_list; + bool status = true; + std::transform(label_ids.begin(), label_ids.end(), std::back_inserter(label_list), + [&all_label, &status](uint32_t idx) -> void * { + if (idx >= all_label.size()) { + MS_LOG(ERROR) << "Invalid label id " << idx << " all label list size " << all_label.size(); + status = false; + return nullptr; + } + return all_label[idx]; + }); + if (!status) { + MS_LOG(ERROR) << "Get label info failed."; + return nullptr; + } + uint32_t label_info_size = sizeof(rtLabelDevInfo) * label_list.size(); + rt_ret = rtMalloc(&label_info, label_info_size, RT_MEMORY_HBM); + if (rt_ret != RT_ERROR_NONE) { + MS_LOG(ERROR) << "Call rt api rtMalloc failed, ret: " << std::hex << rt_ret; + return nullptr; + } + + rt_ret = rtLabelListCpy(label_list.data(), label_list.size(), label_info, label_info_size); + if (rt_ret != RT_ERROR_NONE) { + MS_LOG(ERROR) << "Call rt api rtLabelListCpy failed, ret: " << std::hex << rt_ret; + return nullptr; + } + + auto label_guard = std::make_shared(label_info); + label_map.emplace(label_id_str, label_guard); + return label_guard; +} +} // namespace mindspore::ge::model_runner diff --git a/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/label_manager.h b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/label_manager.h new file mode 100644 index 00000000000..230bf8ed453 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/label_manager.h @@ -0,0 +1,51 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_LABEL_MANAGER_H_ +#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_LABEL_MANAGER_H_ + +#include +#include +#include +#include +#include +#include "runtime/base.h" + +namespace mindspore::ge::model_runner { +class LabelGuard { + public: + explicit LabelGuard(void *label_info) : label_info_(reinterpret_cast(label_info)) {} + ~LabelGuard(); + void *GetLabelInfo() { return reinterpret_cast(label_info_); } + + private: + uintptr_t label_info_; +}; + +class LabelManager { + public: + static std::shared_ptr GetInstance(); + std::shared_ptr GetLabelInfo(rtModel_t model, const std::vector &label_ids, + const std::vector &all_label); + + private: + std::mutex model_info_mapping_mutex_; + std::map>> model_info_mapping_; + + static std::weak_ptr instance_; + static std::mutex instance_mutex_; +}; +} // namespace mindspore::ge::model_runner +#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_LABEL_MANAGER_H_ diff --git a/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/label_set_task.cc b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/label_set_task.cc new file mode 100644 index 00000000000..52a1e883a06 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/label_set_task.cc @@ -0,0 +1,56 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "runtime/device/ascend/ge_runtime/task/label_set_task.h" +#include "runtime/device/ascend/ge_runtime/task/task_factory.h" + +namespace mindspore::ge::model_runner { +LabelSetTask::LabelSetTask(const ModelContext &model_context, const std::shared_ptr &task_info) + : TaskRepeater(model_context, task_info), + task_info_(task_info), + stream_(nullptr), + label_(nullptr) { + MS_EXCEPTION_IF_NULL(task_info_); + auto stream_list = model_context.stream_list(); + auto label_list = model_context.label_list(); + uint32_t stream_id = task_info->stream_id(); + uint32_t label_id = task_info->label_id(); + MS_LOG(INFO) << "Stream list size: " << stream_list.size() << ", stream id: " << stream_id; + MS_LOG(INFO) << "Label list size: " << label_list.size() << ", label id: " << label_id; + if (stream_id >= stream_list.size() || label_id >= label_list.size()) { + MS_LOG(EXCEPTION) << "Stream/Label id invalid."; + } + stream_ = stream_list[stream_id]; + label_ = label_list[label_id]; +} + +LabelSetTask::~LabelSetTask() {} + +void LabelSetTask::Distribute() { + MS_LOG(INFO) << "LabelSetTask Distribute start."; + MS_EXCEPTION_IF_NULL(stream_); + MS_EXCEPTION_IF_NULL(label_); + + rtError_t rt_ret = rtLabelSet(label_, stream_); + if (rt_ret != RT_ERROR_NONE) { + MS_LOG(EXCEPTION) << "Call rt api rtLabelSet failed, ret: " << std::hex << rt_ret; + } + + MS_LOG(INFO) << "DistributeTask end."; +} + +REGISTER_TASK(TaskInfoType::LABEL_SET, LabelSetTask, LabelSetTaskInfo); +} // namespace mindspore::ge::model_runner diff --git a/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/label_set_task.h b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/label_set_task.h new file mode 100644 index 00000000000..54a5d1027c7 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/label_set_task.h @@ -0,0 +1,38 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_LABEL_SET_TASK_H_ +#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_LABEL_SET_TASK_H_ + +#include +#include "runtime/device/ascend/ge_runtime/task/task.h" + +namespace mindspore::ge::model_runner { +class LabelSetTask : public TaskRepeater { + public: + LabelSetTask(const ModelContext &model_context, const std::shared_ptr &task_info); + + ~LabelSetTask() override; + + void Distribute() override; + + private: + std::shared_ptr task_info_; + void *stream_; + void *label_; +}; +} // namespace mindspore::ge::model_runner +#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_LABEL_SET_TASK_H_ diff --git a/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/label_switch_task.cc b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/label_switch_task.cc new file mode 100644 index 00000000000..4672c0d3e40 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/label_switch_task.cc @@ -0,0 +1,77 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "runtime/device/ascend/ge_runtime/task/label_switch_task.h" +#include "runtime/device/ascend/ge_runtime/task/task_factory.h" + +namespace mindspore::ge::model_runner { +LabelSwitchTask::LabelSwitchTask(const ModelContext &model_context, + const std::shared_ptr &task_info) + : TaskRepeater(model_context, task_info), + task_info_(task_info), + stream_(nullptr), + label_info_(nullptr) { + MS_EXCEPTION_IF_NULL(task_info); + + rt_model_handle_ = model_context.rt_model_handle(); + auto all_label_resource = model_context.label_list(); + auto stream_list = model_context.stream_list(); + uint32_t stream_id = task_info->stream_id(); + MS_LOG(INFO) << "Stream list size: " << stream_list.size() << ", stream id: " << stream_id; + if (stream_id >= stream_list.size()) { + MS_LOG(EXCEPTION) << "Stream id invalid."; + } + stream_ = stream_list[stream_id]; + label_manager_ = LabelManager::GetInstance(); + MS_EXCEPTION_IF_NULL(label_manager_); + label_info_ = label_manager_->GetLabelInfo(rt_model_handle_, task_info_->label_list(), all_label_resource); + MS_EXCEPTION_IF_NULL(label_info_); +} + +LabelSwitchTask::~LabelSwitchTask() {} + +void LabelSwitchTask::Distribute() { + MS_LOG(INFO) << "LabelSwitchTask Distribute start."; + CheckParamValid(); + + void *label_info = label_info_->GetLabelInfo(); + rtError_t rt_ret = rtLabelSwitchByIndex(task_info_->cond(), task_info_->label_size(), label_info, stream_); + if (rt_ret != RT_ERROR_NONE) { + MS_LOG(EXCEPTION) << "Call rt api rtLabelSwitchByIndex failed, ret: " << std::hex << rt_ret; + } + + MS_LOG(INFO) << "DistributeTask end."; +} + +void LabelSwitchTask::CheckParamValid() { + MS_EXCEPTION_IF_NULL(stream_); + + if (task_info_->label_list().empty()) { + MS_LOG(EXCEPTION) << "label_list is empty."; + } + + if (task_info_->label_size() != task_info_->label_list().size()) { + MS_LOG(EXCEPTION) << "label_list size " << 
task_info_->label_list().size() << " but label_size is " + << task_info_->label_size(); + } + + if (task_info_->label_size() >= UINT32_MAX / sizeof(rtLabelDevInfo)) { + MS_LOG(EXCEPTION) << "label_size " << task_info_->label_size() << " will overflow."; + } +} + +REGISTER_TASK(TaskInfoType::LABEL_SWITCH, LabelSwitchTask, LabelSwitchTaskInfo); +} // namespace mindspore::ge::model_runner diff --git a/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/label_switch_task.h b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/label_switch_task.h new file mode 100644 index 00000000000..ea4e1551456 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/label_switch_task.h @@ -0,0 +1,43 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_ +#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_ + +#include +#include "runtime/device/ascend/ge_runtime/task/task.h" +#include "runtime/device/ascend/ge_runtime/task/label_manager.h" + +namespace mindspore::ge::model_runner { +class LabelSwitchTask : public TaskRepeater { + public: + LabelSwitchTask(const ModelContext &model_context, const std::shared_ptr &task_info); + + ~LabelSwitchTask() override; + + void Distribute() override; + + private: + void CheckParamValid(); + + std::shared_ptr task_info_; + void *stream_; + rtModel_t rt_model_handle_; + std::shared_ptr label_info_; + std::shared_ptr label_manager_; +}; +} // namespace mindspore::ge::model_runner +#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_ diff --git a/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/memcpy_async_task.cc b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/memcpy_async_task.cc new file mode 100644 index 00000000000..1cea229930d --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/memcpy_async_task.cc @@ -0,0 +1,51 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
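The last check in CheckParamValid rejects label_size before label_size * sizeof(rtLabelDevInfo) can wrap a uint32_t inside LabelManager::GetLabelInfo. The same guard as a standalone helper (illustrative, not patch code):

#include <cstdint>
#include <limits>

// Guarded size computation: reject n before n * elem_size can wrap uint32_t,
// the same idea as LabelSwitchTask::CheckParamValid (>= is deliberately
// conservative, matching the patch).
bool CheckedByteSize(uint32_t n, uint32_t elem_size, uint32_t *out) {
  if (elem_size == 0 || n >= std::numeric_limits<uint32_t>::max() / elem_size) {
    return false;  // n * elem_size would overflow uint32_t
  }
  *out = n * elem_size;
  return true;
}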
+ */ + +#include "runtime/device/ascend/ge_runtime/task/memcpy_async_task.h" +#include "runtime/mem.h" +#include "runtime/device/ascend/ge_runtime/task/task_factory.h" + +namespace mindspore::ge::model_runner { +MemcpyAsyncTask::MemcpyAsyncTask(const ModelContext &model_context, + const std::shared_ptr &task_info) + : TaskRepeater(model_context, task_info), task_info_(task_info), stream_(nullptr) { + MS_EXCEPTION_IF_NULL(task_info); + auto stream_list = model_context.stream_list(); + uint32_t stream_id = task_info->stream_id(); + + MS_LOG(INFO) << "Stream list size: " << stream_list.size() << ", stream id: " << stream_id; + if (stream_id >= stream_list.size()) { + MS_LOG(EXCEPTION) << "Index: " << task_info->stream_id() << " >= stream_list.size(): " << stream_list.size(); + } + stream_ = stream_list[stream_id]; +} + +MemcpyAsyncTask::~MemcpyAsyncTask() {} + +void MemcpyAsyncTask::Distribute() { + MS_LOG(INFO) << "MemcpyAsyncTask Distribute start."; + MS_LOG(INFO) << "dst_max: " << task_info_->dst_max() << ", count: " << task_info_->count() + << ", kind: " << task_info_->kind(); + rtError_t rt_ret = rtMemcpyAsync(task_info_->dst(), task_info_->dst_max(), task_info_->src(), task_info_->count(), + static_cast(task_info_->kind()), stream_); + if (rt_ret != RT_ERROR_NONE) { + MS_LOG(EXCEPTION) << "Call rt api rtMemcpyAsync failed, ret: " << std::hex << rt_ret; + } + MS_LOG(INFO) << "DistributeTask end"; +} + +REGISTER_TASK(TaskInfoType::MEMCPY_ASYNC, MemcpyAsyncTask, MemcpyAsyncTaskInfo); +} // namespace mindspore::ge::model_runner diff --git a/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/memcpy_async_task.h b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/memcpy_async_task.h new file mode 100644 index 00000000000..8e9e9ff9899 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/memcpy_async_task.h @@ -0,0 +1,37 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
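rtMemcpyAsync in MemcpyAsyncTask only enqueues the copy on the task's stream; completion is ordered by the stream (or a later event or synchronize), not by the call returning. A usage sketch, assuming a device-to-device copy and valid handles:

// Enqueue an async device copy, then fence it with a stream synchronize.
void CopyThenSync(void *dst, uint64_t dst_max, void *src, uint64_t count, rtStream_t stream) {
  (void)rtMemcpyAsync(dst, dst_max, src, count, RT_MEMCPY_DEVICE_TO_DEVICE, stream);
  (void)rtStreamSynchronize(stream);  // the copy is complete only after the stream drains
}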
+ */ + +#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_MEMCPY_ASYNC_TASK_H_ +#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_MEMCPY_ASYNC_TASK_H_ + +#include +#include "runtime/device/ascend/ge_runtime/task/task.h" + +namespace mindspore::ge::model_runner { +class MemcpyAsyncTask : public TaskRepeater { + public: + MemcpyAsyncTask(const ModelContext &model_context, const std::shared_ptr &task_info); + + ~MemcpyAsyncTask() override; + + void Distribute() override; + + private: + std::shared_ptr task_info_; + rtStream_t stream_; +}; +} // namespace mindspore::ge::model_runner +#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_MEMCPY_ASYNC_TASK_H_ diff --git a/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/profiler_task.cc b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/profiler_task.cc new file mode 100644 index 00000000000..e2bf4568fee --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/profiler_task.cc @@ -0,0 +1,47 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "runtime/device/ascend/ge_runtime/task/profiler_task.h" +#include "runtime/device/ascend/ge_runtime/task/task_factory.h" + +namespace mindspore::ge::model_runner { +ProfilerTask::ProfilerTask(const ModelContext &model_context, const std::shared_ptr &task_info) + : TaskRepeater(model_context, task_info), task_info_(task_info), stream_(nullptr) { + MS_EXCEPTION_IF_NULL(task_info); + auto stream_list = model_context.stream_list(); + uint32_t stream_id = task_info->stream_id(); + MS_LOG(INFO) << "Stream list size: " << stream_list.size() << ", stream id: " << stream_id; + if (stream_id >= stream_list.size()) { + MS_LOG(EXCEPTION) << "Index: " << task_info->stream_id() << " >= stream_list.size(): " << stream_list.size(); + } + stream_ = stream_list[stream_id]; +} + +ProfilerTask::~ProfilerTask() {} + +void ProfilerTask::Distribute() { + MS_LOG(INFO) << "ProfilerTask Distribute start."; + MS_LOG(INFO) << "log id = " << task_info_->log_id() << ", notify = " << task_info_->notify() + << ", flat = " << task_info_->flat(); + rtError_t rt_ret = rtProfilerTrace(task_info_->log_id(), task_info_->notify(), task_info_->flat(), stream_); + if (rt_ret != RT_ERROR_NONE) { + MS_LOG(EXCEPTION) << "Call rt api rtProfilerTrace failed, ret: " << std::hex << rt_ret; + } + MS_LOG(INFO) << "DistributeTask end."; +} + +REGISTER_TASK(TaskInfoType::PROFILER_TRACE, ProfilerTask, ProfilerTraceTaskInfo); +} // namespace mindspore::ge::model_runner diff --git a/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/profiler_task.h b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/profiler_task.h new file mode 100644 index 00000000000..f0be7da02d1 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/profiler_task.h @@ -0,0 +1,37 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not 
use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_PROFILER_TASK_H_ +#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_PROFILER_TASK_H_ + +#include +#include "runtime/device/ascend/ge_runtime/task/task.h" + +namespace mindspore::ge::model_runner { +class ProfilerTask : public TaskRepeater { + public: + ProfilerTask(const ModelContext &model_context, const std::shared_ptr &task_info); + + ~ProfilerTask() override; + + void Distribute() override; + + private: + std::shared_ptr task_info_; + rtStream_t stream_; +}; +} // namespace mindspore::ge::model_runner +#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_PROFILER_TASK_H_ diff --git a/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/stream_active_task.cc b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/stream_active_task.cc new file mode 100644 index 00000000000..ff7d31f669c --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/stream_active_task.cc @@ -0,0 +1,56 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
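ProfilerTask simply drops an rtProfilerTrace marker into the stream so the device-side timeline can correlate a log_id with the kernels launched around it. A sketch that brackets a region with two markers (the log ids are arbitrary caller-chosen labels; notify mirrors task_info_->notify()):

void TraceRegion(rtStream_t stream, bool notify) {
  (void)rtProfilerTrace(/*log_id=*/1, notify, /*flat=*/0, stream);
  // ... enqueue the kernels to be profiled here ...
  (void)rtProfilerTrace(/*log_id=*/2, notify, /*flat=*/0, stream);
}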
+ */ + +#include "runtime/device/ascend/ge_runtime/task/stream_active_task.h" +#include "runtime/kernel.h" +#include "runtime/device/ascend/ge_runtime/task/task_factory.h" + +namespace mindspore::ge::model_runner { +StreamActiveTask::StreamActiveTask(const ModelContext &model_context, + const std::shared_ptr &task_info) + : TaskRepeater(model_context, task_info), + task_info_(task_info), + stream_(nullptr), + active_stream_(nullptr) { + MS_EXCEPTION_IF_NULL(task_info); + auto stream_list = model_context.stream_list(); + uint32_t stream_id = task_info->stream_id(); + uint32_t active_stream_id = task_info->active_stream_id(); + MS_LOG(INFO) << "Stream list size: " << stream_list.size() << ", stream id: " << stream_id + << ", active stream id: " << active_stream_id; + if (stream_id >= stream_list.size() || active_stream_id >= stream_list.size()) { + MS_LOG(EXCEPTION) << "Stream id invalid"; + } + stream_ = stream_list[stream_id]; + active_stream_ = stream_list[active_stream_id]; +} + +StreamActiveTask::~StreamActiveTask() {} + +void StreamActiveTask::Distribute() { + MS_LOG(INFO) << "Distribute start"; + MS_LOG(INFO) << "Stream " << task_info_->stream_id() << " active " << task_info_->active_stream_id(); + MS_EXCEPTION_IF_NULL(stream_); + MS_EXCEPTION_IF_NULL(active_stream_); + rtError_t rt_ret = rtStreamActive(active_stream_, stream_); + if (rt_ret != RT_ERROR_NONE) { + MS_LOG(EXCEPTION) << "Call rt api rtStreamActive failed, ret: " << std::hex << rt_ret; + } + MS_LOG(INFO) << "DistributeTask end."; +} + +REGISTER_TASK(TaskInfoType::STREAM_ACTIVE, StreamActiveTask, StreamActiveTaskInfo); +} // namespace mindspore::ge::model_runner diff --git a/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/stream_active_task.h b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/stream_active_task.h new file mode 100644 index 00000000000..359b09cf44d --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/stream_active_task.h @@ -0,0 +1,38 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
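A stream bound with RT_MODEL_WAIT_ACTIVE_STREAM (as the HCCL secondary streams above are) stays idle during model execution until another stream activates it; StreamActiveTask is that activation, and StreamSwitchTask in the next file supplies the condition that usually precedes it. The pairing in isolation (sketch, handles assumed valid):

// Stream a activates stream b. b must have been bound to the model with
// RT_MODEL_WAIT_ACTIVE_STREAM so it waits until this point is reached on a.
void ActivateBranch(rtStream_t a, rtStream_t b) {
  (void)rtStreamActive(b, a);  // enqueued on a; wakes b when executed
}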
diff --git a/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/stream_active_task.h b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/stream_active_task.h
new file mode 100644
index 00000000000..359b09cf44d
--- /dev/null
+++ b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/stream_active_task.h
@@ -0,0 +1,38 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_STREAM_ACTIVE_TASK_H_
+#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_STREAM_ACTIVE_TASK_H_
+
+#include <memory>
+#include "runtime/device/ascend/ge_runtime/task/task.h"
+
+namespace mindspore::ge::model_runner {
+class StreamActiveTask : public TaskRepeater<StreamActiveTaskInfo> {
+ public:
+  StreamActiveTask(const ModelContext &model_context, const std::shared_ptr<StreamActiveTaskInfo> &task_info);
+
+  ~StreamActiveTask() override;
+
+  void Distribute() override;
+
+ private:
+  std::shared_ptr<StreamActiveTaskInfo> task_info_;
+  rtStream_t stream_;
+  rtStream_t active_stream_;
+};
+}  // namespace mindspore::ge::model_runner
+#endif  // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_STREAM_ACTIVE_TASK_H_
Distribute Start."; + rtError_t rt_ret = rtStreamSwitchEx(input, cond, value, true_stream, stream_, data_type); + if (rt_ret != RT_ERROR_NONE) { + MS_LOG(EXCEPTION) << "Call rt api rtStreamSwitchEx failed, ret: " << std::hex << rt_ret; + } + + MS_LOG(INFO) << "Distribute StreamSwitch success"; +} + +REGISTER_TASK(TaskInfoType::STREAM_SWITCH, StreamSwitchTask, StreamSwitchTaskInfo); +} // namespace mindspore::ge::model_runner diff --git a/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/stream_switch_task.h b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/stream_switch_task.h new file mode 100644 index 00000000000..686bf6bb1e4 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/stream_switch_task.h @@ -0,0 +1,40 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_STREAM_SWITCH_TASK_H_ +#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_STREAM_SWITCH_TASK_H_ + +#include +#include +#include "runtime/device/ascend/ge_runtime/task/task.h" + +namespace mindspore::ge::model_runner { +class StreamSwitchTask : public TaskRepeater { + public: + StreamSwitchTask(const ModelContext &model_context, const std::shared_ptr &task_info); + + ~StreamSwitchTask() override; + + void Distribute() override; + + private: + std::shared_ptr task_info_; + + void *stream_; + std::vector stream_list_; +}; +} // namespace mindspore::ge::model_runner +#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_STREAM_SWITCH_TASK_H_ diff --git a/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/task.h b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/task.h new file mode 100644 index 00000000000..3290f7a0012 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/task.h @@ -0,0 +1,53 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
diff --git a/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/stream_switch_task.h b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/stream_switch_task.h
new file mode 100644
index 00000000000..686bf6bb1e4
--- /dev/null
+++ b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/stream_switch_task.h
@@ -0,0 +1,40 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_STREAM_SWITCH_TASK_H_
+#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_STREAM_SWITCH_TASK_H_
+
+#include <memory>
+#include <vector>
+#include "runtime/device/ascend/ge_runtime/task/task.h"
+
+namespace mindspore::ge::model_runner {
+class StreamSwitchTask : public TaskRepeater<StreamSwitchTaskInfo> {
+ public:
+  StreamSwitchTask(const ModelContext &model_context, const std::shared_ptr<StreamSwitchTaskInfo> &task_info);
+
+  ~StreamSwitchTask() override;
+
+  void Distribute() override;
+
+ private:
+  std::shared_ptr<StreamSwitchTaskInfo> task_info_;
+
+  void *stream_;
+  std::vector<void *> stream_list_;
+};
+}  // namespace mindspore::ge::model_runner
+#endif  // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_STREAM_SWITCH_TASK_H_
diff --git a/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/task.h b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/task.h
new file mode 100644
index 00000000000..3290f7a0012
--- /dev/null
+++ b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/task.h
@@ -0,0 +1,53 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_TASK_H_
+#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_TASK_H_
+
+#include <memory>
+#include <string>
+#include <type_traits>
+#include <utility>
+#include "runtime/device/ascend/ge_runtime/model_context.h"
+#include "runtime/device/ascend/ge_runtime/task_info.h"
+
+namespace mindspore::ge::model_runner {
+class Task {
+ public:
+  Task() {}
+
+  virtual ~Task() {}
+
+  virtual void Distribute() = 0;
+
+  virtual void *Args() { return nullptr; }
+
+  virtual std::string task_name() const { return ""; }
+};
+
+template <class T>
+class TaskRepeater : public Task {
+  static_assert(std::is_base_of<TaskInfo, T>(), "Wrong TaskInfo Type!");
+
+ public:
+  TaskRepeater(const ModelContext &model_context, const std::shared_ptr<T> &task_info) {}
+
+  virtual ~TaskRepeater() {}
+
+  virtual void Distribute() = 0;
+};
+}  // namespace mindspore::ge::model_runner
+#endif  // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_TASK_H_
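For context, a concrete task type only has to derive from TaskRepeater with its matching TaskInfo subclass (the static_assert above enforces that relationship) and implement Distribute(). A minimal sketch follows; it reuses the real LabelSetTaskInfo type, but NoopTask itself is hypothetical and not part of this patch:

    // Hypothetical subclass showing the TaskRepeater contract.
    class NoopTask : public TaskRepeater<LabelSetTaskInfo> {
     public:
      NoopTask(const ModelContext &model_context, const std::shared_ptr<LabelSetTaskInfo> &task_info)
          : TaskRepeater<LabelSetTaskInfo>(model_context, task_info) {}
      void Distribute() override {}  // a real task would issue rt* runtime calls here
    };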
diff --git a/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/task_factory.h b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/task_factory.h
new file mode 100644
index 00000000000..4dc284fc5f6
--- /dev/null
+++ b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/task_factory.h
@@ -0,0 +1,84 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_TASK_FACTORY_H_
+#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_TASK_FACTORY_H_
+
+#include <functional>
+#include <map>
+#include <memory>
+#include <utility>
+#include "runtime/device/ascend/ge_runtime/task_info.h"
+#include "mindspore/core/utils/log_adapter.h"
+
+namespace mindspore::ge::model_runner {
+class Task;
+class ModelContext;
+using TASK_CREATOR_FUN = std::function<std::shared_ptr<Task>(const ModelContext &, std::shared_ptr<TaskInfo>)>;
+
+class TaskFactory {
+ private:
+  TaskFactory() {}
+  ~TaskFactory() {}
+  void RegisterCreator(const TaskInfoType &type, const TASK_CREATOR_FUN &func) {
+    if (creator_map_.find(type) != creator_map_.end()) {
+      MS_LOG(WARNING) << "Creator type " << type << " already exist.";
+    }
+    creator_map_[type] = func;
+  }
+
+  std::map<TaskInfoType, TASK_CREATOR_FUN> creator_map_;
+
+ public:
+  static TaskFactory &GetInstance() {
+    static TaskFactory instance;
+    return instance;
+  }
+
+  std::shared_ptr<Task> Create(const ModelContext &model_context, const std::shared_ptr<TaskInfo> &task_info) const {
+    if (task_info == nullptr) {
+      MS_LOG(ERROR) << "task_info is null.";
+      return nullptr;
+    }
+
+    auto iter = creator_map_.find(task_info->type());
+    if (iter == creator_map_.end()) {
+      MS_LOG(ERROR) << "Unknown task type " << task_info->type();
+      return nullptr;
+    }
+    return iter->second(model_context, task_info);
+  }
+
+  class Register {
+   public:
+    Register(const TaskInfoType &type, const TASK_CREATOR_FUN &func) {
+      MS_LOG(DEBUG) << "register type " << type;
+      TaskFactory::GetInstance().RegisterCreator(type, func);
+    }
+
+    ~Register() {}
+  };
+};
+
+#define REGISTER_TASK(type, task_clazz, task_info_clazz)                                                           \
+  TaskFactory::Register g_##task_clazz##_register(                                                                 \
+    type,                                                                                                          \
+    [](const ModelContext &model_context, const std::shared_ptr<TaskInfo> &task_info) -> std::shared_ptr<Task> {   \
+      std::shared_ptr<task_info_clazz> concrete_task_info = std::static_pointer_cast<task_info_clazz>(task_info);  \
+      return std::make_shared<task_clazz>(model_context, concrete_task_info);                                      \
+    });
+
+}  // namespace mindspore::ge::model_runner
+#endif  // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_TASK_FACTORY_H_
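Registration and lookup tie together as follows: each REGISTER_TASK expansion defines a file-scope TaskFactory::Register object whose constructor stores a creator lambda keyed by TaskInfoType, and loading code recovers the concrete task through TaskFactory::Create. A usage sketch, mirroring how the unit tests below drive the factory (model_context is assumed to be an already-built ModelContext):

    // Create dispatches on the TaskInfoType carried by the descriptor.
    std::shared_ptr<TaskInfo> info = std::make_shared<LabelSetTaskInfo>("op_name", 0, 0);
    std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, info);
    if (task != nullptr) {
      task->Distribute();  // issues the runtime calls for this task type
    }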
diff --git a/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/tbe_task.cc b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/tbe_task.cc
new file mode 100644
index 00000000000..00d35f4f13f
--- /dev/null
+++ b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/tbe_task.cc
@@ -0,0 +1,97 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "runtime/device/ascend/ge_runtime/task/tbe_task.h"
+#include <vector>
+#include "runtime/mem.h"
+#include "runtime/kernel.h"
+#include "runtime/device/ascend/ge_runtime/task/task_factory.h"
+
+namespace mindspore::ge::model_runner {
+TbeTask::TbeTask(const ModelContext &model_context, const std::shared_ptr<TbeTaskInfo> &task_info)
+    : TaskRepeater<TbeTaskInfo>(model_context, task_info),
+      task_info_(task_info),
+      stream_(nullptr),
+      stub_func_(nullptr),
+      args_(nullptr) {
+  MS_EXCEPTION_IF_NULL(task_info);
+
+  auto stream_list = model_context.stream_list();
+  if (stream_list.size() == 1) {
+    stream_ = stream_list[0];
+  } else if (stream_list.size() > task_info->stream_id()) {
+    stream_ = stream_list[task_info->stream_id()];
+  } else {
+    MS_LOG(EXCEPTION) << "Index: " << task_info->stream_id() << " >= stream_list.size(): " << stream_list.size();
+  }
+}
+
+TbeTask::~TbeTask() {
+  if (args_ != nullptr) {
+    rtError_t rt_ret = rtFree(args_);
+    if (rt_ret != RT_ERROR_NONE) {
+      MS_LOG(ERROR) << "Call rt api rtFree failed, ret: " << std::hex << rt_ret;
+    }
+    args_ = nullptr;
+  }
+}
+
+void TbeTask::Distribute() {
+  MS_LOG(INFO) << "InitTbeTask start.";
+  MS_EXCEPTION_IF_NULL(stream_);
+  // Get stub_func
+  if (task_info_->stub_func().empty()) {
+    MS_LOG(EXCEPTION) << "kernel_info->stub_func is empty!";
+  }
+
+  rtError_t rt_ret = rtGetFunctionByName(const_cast<char *>(task_info_->stub_func().c_str()), &stub_func_);
+  if (rt_ret != RT_ERROR_NONE) {
+    MS_LOG(EXCEPTION) << "Call rt api rtGetFunctionByName failed, ret: " << std::hex << rt_ret;
+  }
+  MS_LOG(INFO) << "TbeTask: stub_func = " << task_info_->stub_func();
+
+  // Get args
+  std::vector<void *> tensor_device_addrs;
+  tensor_device_addrs.insert(tensor_device_addrs.end(), task_info_->input_data_addrs().begin(),
+                             task_info_->input_data_addrs().end());
+  tensor_device_addrs.insert(tensor_device_addrs.end(), task_info_->output_data_addrs().begin(),
+                             task_info_->output_data_addrs().end());
+  tensor_device_addrs.insert(tensor_device_addrs.end(), task_info_->workspace_addrs().begin(),
+                             task_info_->workspace_addrs().end());
+  auto args_size = static_cast<uint32_t>(tensor_device_addrs.size() * sizeof(void *));
+
+  rt_ret = rtMalloc(&args_, args_size, RT_MEMORY_HBM);
+  if (rt_ret != RT_ERROR_NONE) {
+    MS_LOG(EXCEPTION) << "Call rt api rtMalloc failed, ret: " << std::hex << rt_ret << " mem size " << args_size;
+  }
+
+  rt_ret = rtMemcpy(args_, args_size, reinterpret_cast<void *>(tensor_device_addrs.data()), args_size,
+                    RT_MEMCPY_HOST_TO_DEVICE);
+  if (rt_ret != RT_ERROR_NONE) {
+    MS_LOG(EXCEPTION) << "Call rt api rtMemcpy failed, ret: " << std::hex << rt_ret;
+  }
+
+  MS_LOG(INFO) << "DistributeTbeTask start.";
+  auto dump_flag = task_info_->dump_flag() ? RT_KERNEL_DUMPFLAG : RT_KERNEL_DEFAULT;
+  rt_ret = rtKernelLaunchWithFlag(stub_func_, task_info_->block_dim(), args_, args_size, nullptr, stream_, dump_flag);
+  if (rt_ret != RT_ERROR_NONE) {
+    MS_LOG(EXCEPTION) << "Call rt api rtKernelLaunchWithFlag failed, ret: " << std::hex << rt_ret << " mem size "
+                      << args_size;
+  }
+  MS_LOG(INFO) << "[DataDump] task name: " << task_info_->op_name() << " dump_flag: " << dump_flag;
+}
+
+REGISTER_TASK(TaskInfoType::TBE, TbeTask, TbeTaskInfo);
+}  // namespace mindspore::ge::model_runner
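The args buffer built above is the entire kernel ABI for a TBE task: a flat device-side array of pointers, inputs first, then outputs, then workspaces, copied once with rtMemcpy and handed to rtKernelLaunchWithFlag. A layout sketch, with inputs/outputs/workspaces as stand-ins for the task_info_ accessors used in the real code:

    // Flat pointer-array layout consumed by a TBE kernel (sketch, not part of the patch).
    std::vector<void *> addrs;
    addrs.insert(addrs.end(), inputs.begin(), inputs.end());          // input tensor addresses
    addrs.insert(addrs.end(), outputs.begin(), outputs.end());        // output tensor addresses
    addrs.insert(addrs.end(), workspaces.begin(), workspaces.end());  // scratch buffers
    auto args_size = static_cast<uint32_t>(addrs.size() * sizeof(void *));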
diff --git a/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/tbe_task.h b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/tbe_task.h
new file mode 100644
index 00000000000..5c75223feb6
--- /dev/null
+++ b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task/tbe_task.h
@@ -0,0 +1,44 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_TBE_TASK_H_
+#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_TBE_TASK_H_
+
+#include <memory>
+#include <string>
+#include "runtime/device/ascend/ge_runtime/task/task.h"
+
+namespace mindspore::ge::model_runner {
+class TbeTask : public TaskRepeater<TbeTaskInfo> {
+ public:
+  TbeTask(const ModelContext &model_context, const std::shared_ptr<TbeTaskInfo> &task_info);
+
+  ~TbeTask() override;
+
+  void Distribute() override;
+
+  void *Args() override { return args_; }
+
+  std::string task_name() const override { return task_info_->op_name(); }
+
+ private:
+  std::shared_ptr<TbeTaskInfo> task_info_;
+  void *stream_;
+  void *stub_func_;
+  void *args_;
+};
+}  // namespace mindspore::ge::model_runner
+#endif  // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_TBE_TASK_H_
diff --git a/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task_info.h b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task_info.h
new file mode 100644
index 00000000000..6afb0c3c839
--- /dev/null
+++ b/mindspore/ccsrc/runtime/device/ascend/ge_runtime/task_info.h
@@ -0,0 +1,364 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_INFO_H_
+#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_INFO_H_
+
+#include <functional>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+namespace mindspore::ge::model_runner {
+enum TaskInfoType {
+  CCE = 0,
+  TBE,
+  AICPU,
+  LABEL_SET,
+  LABEL_SWITCH,
+  LABEL_GOTO,
+  EVENT_RECORD,
+  EVENT_WAIT,
+  FUSION_START,
+  FUSION_END,
+  HCCL,
+  PROFILER_TRACE,
+  MEMCPY_ASYNC,
+  STREAM_SWITCH,
+  STREAM_ACTIVE,
+  // Insert new task type here
+  RESERVED = 23
+};
+
+class TaskInfo {
+ public:
+  virtual ~TaskInfo() {}
+  uint32_t stream_id() const { return stream_id_; }
+  TaskInfoType type() const { return type_; }
+  std::string op_name() const { return op_name_; }
+  bool dump_flag() const { return dump_flag_; }
+
+ protected:
+  TaskInfo(const std::string &op_name, uint32_t stream_id, TaskInfoType type, bool dump_flag)
+      : op_name_(op_name), stream_id_(stream_id), type_(type), dump_flag_(dump_flag) {}
+
+ private:
+  std::string op_name_;
+  uint32_t stream_id_;
+  TaskInfoType type_;
+  bool dump_flag_;
+};
+
+class TbeTaskInfo : public TaskInfo {
+ public:
+  TbeTaskInfo(const std::string &op_name, uint32_t stream_id, const std::string &stub_func, uint32_t block_dim,
+              const std::vector<uint8_t> &args, uint32_t args_size, const std::vector<uint8_t> &sm_desc, void *binary,
+              uint32_t binary_size, const std::vector<uint8_t> &meta_data, const std::vector<void *> &input_data_addrs,
+              const std::vector<void *> &output_data_addrs, const std::vector<void *> &workspace_addrs, bool dump_flag)
+      : TaskInfo(op_name, stream_id, TaskInfoType::TBE, dump_flag),
+        stub_func_(stub_func),
+        block_dim_(block_dim),
+        args_(args),
+        args_size_(args_size),
+        sm_desc_(sm_desc),
+        binary_(binary),
+        binary_size_(binary_size),
+        meta_data_(meta_data),
+        input_data_addrs_(input_data_addrs),
+        output_data_addrs_(output_data_addrs),
+        workspace_addrs_(workspace_addrs) {}
+  ~TbeTaskInfo() override {}
+
+  const std::string &stub_func() const { return stub_func_; }
+  uint32_t block_dim() const { return block_dim_; }
+  const std::vector<uint8_t> &args() const { return args_; }
+  uint32_t args_size() const { return args_size_; }
+  const std::vector<uint8_t> &sm_desc() const { return sm_desc_; }
+  void *binary() const { return binary_; }
+  uint32_t binary_size() const { return binary_size_; }
+  const std::vector<uint8_t> &meta_data() const { return meta_data_; }
+  const std::vector<void *> &input_data_addrs() const { return input_data_addrs_; }
+  const std::vector<void *> &output_data_addrs() const { return output_data_addrs_; }
+  const std::vector<void *> &workspace_addrs() const { return workspace_addrs_; }
+
+  void SetBinary(void *binary, uint32_t binary_size) {
+    binary_ = binary;
+    binary_size_ = binary_size;
+  }
+
+ private:
+  std::string stub_func_;
+  uint32_t block_dim_;
+  std::vector<uint8_t> args_;
+  uint32_t args_size_;
+  std::vector<uint8_t> sm_desc_;
+  void *binary_;
+  uint32_t binary_size_;
+  std::vector<uint8_t> meta_data_;
+  std::vector<void *> input_data_addrs_;
+  std::vector<void *> output_data_addrs_;
+  std::vector<void *> workspace_addrs_;
+};
+
+class AicpuTaskInfo : public TaskInfo {
+ public:
+  AicpuTaskInfo(const std::string &op_name, uint32_t stream_id, const std::string &so_name,
+                const std::string &kernel_name, const std::string &node_def, const std::string &ext_info,
+                const std::vector<void *> &input_data_addrs, const std::vector<void *> &output_data_addrs,
+                bool dump_flag)
+      : TaskInfo(op_name, stream_id, TaskInfoType::AICPU, dump_flag),
+        so_name_(so_name),
+        kernel_name_(kernel_name),
+        node_def_(node_def),
+        ext_info_(ext_info),
+        input_data_addrs_(input_data_addrs),
+        output_data_addrs_(output_data_addrs) {}
+  ~AicpuTaskInfo() override {}
+
+  const std::string &so_name() const { return so_name_; }
+  const std::string &kernel_name() const { return kernel_name_; }
+  const std::string &node_def() const { return node_def_; }
+  const std::vector<void *> &input_data_addrs() const { return input_data_addrs_; }
+  const std::vector<void *> &output_data_addrs() const { return output_data_addrs_; }
+  const std::string &ext_info() const { return ext_info_; }
+
+ private:
+  std::string so_name_;
+  std::string kernel_name_;
+  std::string node_def_;
+  std::string ext_info_;
+  std::vector<void *> input_data_addrs_;
+  std::vector<void *> output_data_addrs_;
+};
+
+class LabelSetTaskInfo : public TaskInfo {
+ public:
+  LabelSetTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_id)
+      : TaskInfo(op_name, stream_id, TaskInfoType::LABEL_SET, false), label_id_(label_id) {}
+  ~LabelSetTaskInfo() override {}
+  uint32_t label_id() const { return label_id_; }
+
+ private:
+  uint32_t label_id_;
+};
+
+class LabelGotoTaskInfo : public TaskInfo {
+ public:
+  LabelGotoTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_id)
+      : TaskInfo(op_name, stream_id, TaskInfoType::LABEL_GOTO, false), label_id_(label_id) {}
+  ~LabelGotoTaskInfo() override {}
+  uint32_t label_id() const { return label_id_; }
+
+ private:
+  uint32_t label_id_;
+};
+
+class LabelSwitchTaskInfo : public TaskInfo {
+ public:
+  LabelSwitchTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_size,
+                      const std::vector<uint32_t> &label_list, void *cond)
+      : TaskInfo(op_name, stream_id, TaskInfoType::LABEL_SWITCH, false),
+        label_size_(label_size),
+        label_list_(label_list),
+        cond_(cond) {}
+  ~LabelSwitchTaskInfo() override {}
+  uint32_t label_size() const { return label_size_; }
+  const std::vector<uint32_t> &label_list() const { return label_list_; }
+  void *cond() const { return cond_; }
+
+ private:
+  uint32_t label_size_;
+  std::vector<uint32_t> label_list_;
+  void *cond_;
+};
+
+class EventTaskInfo : public TaskInfo {
+ public:
+  uint32_t event_id() const { return event_id_; }
+
+ protected:
+  EventTaskInfo(const std::string &op_name, uint32_t stream_id, TaskInfoType type, uint32_t event_id)
+      : TaskInfo(op_name, stream_id, type, false), event_id_(event_id) {}
+  ~EventTaskInfo() override {}
+
+  uint32_t event_id_;
+};
+
+class EventRecordTaskInfo : public EventTaskInfo {
+ public:
+  EventRecordTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t event_id)
+      : EventTaskInfo(op_name, stream_id, TaskInfoType::EVENT_RECORD, event_id) {}
+  ~EventRecordTaskInfo() override {}
+};
+
+class EventWaitTaskInfo : public EventTaskInfo {
+ public:
+  EventWaitTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t event_id)
+      : EventTaskInfo(op_name, stream_id, TaskInfoType::EVENT_WAIT, event_id) {}
+  ~EventWaitTaskInfo() override {}
+};
+
+class FusionStartTaskInfo : public TaskInfo {
+ public:
+  explicit FusionStartTaskInfo(const std::string &op_name, uint32_t stream_id)
+      : TaskInfo(op_name, stream_id, TaskInfoType::FUSION_START, false) {}
+  ~FusionStartTaskInfo() override {}
+};
+
+class FusionEndTaskInfo : public TaskInfo {
+ public:
+  explicit FusionEndTaskInfo(const std::string &op_name, uint32_t stream_id)
+      : TaskInfo(op_name, stream_id, TaskInfoType::FUSION_END, false) {}
+  ~FusionEndTaskInfo() override {}
+};
+
+class HcclTaskInfo : public TaskInfo {
+ public:
+  HcclTaskInfo(const std::string &op_name, uint32_t stream_id, const std::string &hccl_type, void *input_data_addr,
+               void *output_data_addr, void *workspace_addr, int64_t workspace_size, int64_t hccl_stream_num,
+               const std::vector<uint8_t> &private_def, void *ops_kernel_store, int32_t count, int64_t root_id,
+               int64_t op_type, int64_t data_type, const std::string &group, bool dump_flag)
+      : TaskInfo(op_name, stream_id, TaskInfoType::HCCL, dump_flag),
+        hccl_type_(hccl_type),
+        input_data_addr_(input_data_addr),
+        output_data_addr_(output_data_addr),
+        workspace_addr_(workspace_addr),
+        workspace_size_(workspace_size),
+        hccl_stream_num_(hccl_stream_num),
+        private_def_(private_def),
+        ops_kernel_store_(ops_kernel_store),
+        count_(count),
+        root_id_(root_id),
+        op_type_(op_type),
+        data_type_(data_type),
+        group_(group) {}
+  ~HcclTaskInfo() override {}
+
+  const std::string &hccl_type() const { return hccl_type_; }
+  void *input_data_addr() const { return input_data_addr_; }
+  void *output_data_addr() const { return output_data_addr_; }
+  void *workspace_addr() const { return workspace_addr_; }
+  int64_t workspace_size() const { return workspace_size_; }
+  int64_t hccl_stream_num() const { return hccl_stream_num_; }
+  const std::vector<uint8_t> &private_def() const { return private_def_; }
+  void *ops_kernel_store() const { return ops_kernel_store_; }
+  int32_t count() const { return count_; }
+  int64_t root_id() const { return root_id_; }
+  int64_t op_type() const { return op_type_; }
+  int64_t data_type() const { return data_type_; }
+  const std::string &group() const { return group_; }
+
+ private:
+  std::string hccl_type_;
+  void *input_data_addr_;
+  void *output_data_addr_;
+  void *workspace_addr_;
+  int64_t workspace_size_;
+  int64_t hccl_stream_num_;
+  std::vector<uint8_t> private_def_;
+  void *ops_kernel_store_;
+  int32_t count_;
+  int64_t root_id_;
+  int64_t op_type_;
+  int64_t data_type_;
+  std::string group_;
+};
+
+class ProfilerTraceTaskInfo : public TaskInfo {
+ public:
+  ProfilerTraceTaskInfo(const std::string &op_name, uint32_t stream_id, uint64_t log_id, bool notify, uint32_t flat)
+      : TaskInfo(op_name, stream_id, TaskInfoType::PROFILER_TRACE, false),
+        log_id_(log_id),
+        notify_(notify),
+        flat_(flat) {}
+  ~ProfilerTraceTaskInfo() override {}
+
+  uint64_t log_id() const { return log_id_; }
+  bool notify() const { return notify_; }
+  uint32_t flat() const { return flat_; }
+
+ private:
+  uint64_t log_id_;
+  bool notify_;
+  uint32_t flat_;
+};
+
+class MemcpyAsyncTaskInfo : public TaskInfo {
+ public:
+  MemcpyAsyncTaskInfo(const std::string &op_name, uint32_t stream_id, void *dst, uint64_t dst_max, void *src,
+                      uint64_t count, uint32_t kind, bool dump_flag)
+      : TaskInfo(op_name, stream_id, TaskInfoType::MEMCPY_ASYNC, dump_flag),
+        dst_(dst),
+        dst_max_(dst_max),
+        src_(src),
+        count_(count),
+        kind_(kind) {}
+  ~MemcpyAsyncTaskInfo() override {}
+
+  void *dst() const { return dst_; }
+  uint64_t dst_max() const { return dst_max_; }
+  void *src() const { return src_; }
+  uint64_t count() const { return count_; }
+  uint32_t kind() const { return kind_; }
+
+ private:
+  void *dst_;
+  uint64_t dst_max_;
+  void *src_;
+  uint64_t count_;
+  uint32_t kind_;
+};
+
+class StreamSwitchTaskInfo : public TaskInfo {
+ public:
+  StreamSwitchTaskInfo(const std::string &op_name, uint32_t stream_id, int64_t true_stream_id, void *input_addr,
+                       void *value_addr, int64_t cond, int64_t data_type)
+      : TaskInfo(op_name, stream_id, TaskInfoType::STREAM_SWITCH, false),
+        true_stream_id_(true_stream_id),
+        input_addr_(input_addr),
+        value_addr_(value_addr),
+        cond_(cond),
+        data_type_(data_type) {}
+  ~StreamSwitchTaskInfo() override {}
+
+  int64_t true_stream_id() const { return true_stream_id_; }
+  void *input_addr() const { return input_addr_; }
+  void *value_addr() const { return value_addr_; }
+  int64_t cond() const { return cond_; }
+  int64_t data_type() const { return data_type_; }
+
+ private:
+  int64_t true_stream_id_;
+  void *input_addr_;
+  void *value_addr_;
+  int64_t cond_;
+  int64_t data_type_;
+};
+
+class StreamActiveTaskInfo : public TaskInfo {
+ public:
+  StreamActiveTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t active_stream_id)
+      : TaskInfo(op_name, stream_id, TaskInfoType::STREAM_ACTIVE, false), active_stream_id_(active_stream_id) {}
+  ~StreamActiveTaskInfo() override {}
+
+  uint32_t active_stream_id() const { return active_stream_id_; }
+
+ private:
+  uint32_t active_stream_id_;
+};
+}  // namespace mindspore::ge::model_runner
+#endif  // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_INFO_H_
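Each descriptor above is a plain value object that is fully specified at construction time; the loader never mutates it (SetBinary on TbeTaskInfo is the one exception). For example, an async memcpy descriptor could be built as below; dst_ptr and src_ptr are illustrative placeholders, and the kind value 1 assumes the rtMemcpyKind_t convention where host-to-device is 1:

    // Sketch: describe copying 512 bytes into a 1024-byte destination on stream 0.
    auto memcpy_info = std::make_shared<MemcpyAsyncTaskInfo>(
      "memcpy_op", /* stream_id */ 0, dst_ptr, /* dst_max */ 1024, src_ptr,
      /* count */ 512, /* kind */ 1, /* dump_flag */ false);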
diff --git a/mindspore/ccsrc/runtime/device/ascend/tasksink/task_generator.h b/mindspore/ccsrc/runtime/device/ascend/tasksink/task_generator.h
index a77e5d0c06c..60007668748 100644
--- a/mindspore/ccsrc/runtime/device/ascend/tasksink/task_generator.h
+++ b/mindspore/ccsrc/runtime/device/ascend/tasksink/task_generator.h
@@ -25,7 +25,7 @@
 #include "runtime/device/kernel_runtime.h"
 #include "ir/anf.h"
 #include "backend/kernel_compiler/ascend_kernel_mod.h"
-#include "framework/ge_runtime/task_info.h"
+#include "runtime/device/ascend/ge_runtime/task_info.h"
 
 namespace mindspore {
 namespace device {
diff --git a/mindspore/core/utils/log_adapter.h b/mindspore/core/utils/log_adapter.h
index 4f2d539f184..54d9f08d880 100644
--- a/mindspore/core/utils/log_adapter.h
+++ b/mindspore/core/utils/log_adapter.h
@@ -134,6 +134,7 @@ enum SubModuleId : int {
   SM_HCCL_ADPT,          // Hccl Adapter
   SM_MINDQUANTUM,        // MindQuantum
   SM_RUNTIME_FRAMEWORK,  // Runtime framework
+  SM_GE,                 // GraphEngine
   NUM_SUBMODUES          // number of submodules
 };
 
@@ -142,34 +143,35 @@
 #endif
 
 static const char *SUB_MODULE_NAMES[NUM_SUBMODUES] = {
-  "UNKNOWN",            // SM_UNKNOWN
-  "CORE",               // SM_CORE
-  "ANALYZER",           // SM_ANALYZER
-  "COMMON",             // SM_COMMON
-  "DEBUG",              // SM_DEBUG
-  "OFFLINE_DEBUG",      // SM_OFFLINE_DEBUG
-  "DEVICE",             // SM_DEVICE
-  "GE_ADPT",            // SM_GE_ADPT
-  "IR",                 // SM_IR
-  "KERNEL",             // SM_KERNEL
-  "MD",                 // SM_MD
-  "ME",                 // SM_ME
-  "EXPRESS",            // SM_EXPRESS
-  "OPTIMIZER",          // SM_OPTIMIZER
-  "PARALLEL",           // SM_PARALLEL
-  "PARSER",             // SM_PARSER
-  "PIPELINE",           // SM_PIPELINE
-  "PRE_ACT",            // SM_PRE_ACT
-  "PYNATIVE",           // SM_PYNATIVE
-  "SESSION",            // SM_SESSION
-  "UTILS",              // SM_UTILS
-  "VM",                 // SM_VM
-  "PROFILER",           // SM_PROFILER
-  "PS",                 // SM_PS
-  "LITE",               // SM_LITE
-  "HCCL_ADPT",          // SM_HCCL_ADPT
-  "MINDQUANTUM",        // SM_MINDQUANTUM
-  "RUNTIME_FRAMEWORK"   // SM_RUNTIME_FRAMEWORK
+  "UNKNOWN",            // SM_UNKNOWN
+  "CORE",               // SM_CORE
+  "ANALYZER",           // SM_ANALYZER
+  "COMMON",             // SM_COMMON
+  "DEBUG",              // SM_DEBUG
+  "OFFLINE_DEBUG",      // SM_OFFLINE_DEBUG
+  "DEVICE",             // SM_DEVICE
+  "GE_ADPT",            // SM_GE_ADPT
+  "IR",                 // SM_IR
+  "KERNEL",             // SM_KERNEL
+  "MD",                 // SM_MD
+  "ME",                 // SM_ME
+  "EXPRESS",            // SM_EXPRESS
+  "OPTIMIZER",          // SM_OPTIMIZER
+  "PARALLEL",           // SM_PARALLEL
+  "PARSER",             // SM_PARSER
+  "PIPELINE",           // SM_PIPELINE
+  "PRE_ACT",            // SM_PRE_ACT
+  "PYNATIVE",           // SM_PYNATIVE
+  "SESSION",            // SM_SESSION
+  "UTILS",              // SM_UTILS
+  "VM",                 // SM_VM
+  "PROFILER",           // SM_PROFILER
+  "PS",                 // SM_PS
+  "LITE",               // SM_LITE
+  "HCCL_ADPT",          // SM_HCCL_ADPT
+  "MINDQUANTUM",        // SM_MINDQUANTUM
+  "RUNTIME_FRAMEWORK",  // SM_RUNTIME_FRAMEWORK
+  "GE",                 // SM_GE
 };
 
 #if defined(_WIN32) || defined(_WIN64)
diff --git a/tests/ut/cpp/CMakeLists.txt b/tests/ut/cpp/CMakeLists.txt
index 11d017da300..f484d94bc09 100644
--- a/tests/ut/cpp/CMakeLists.txt
+++ b/tests/ut/cpp/CMakeLists.txt
@@ -23,6 +23,7 @@ include_directories(${CMAKE_SOURCE_DIR}/mindspore/core)
 include_directories(${CMAKE_CURRENT_SOURCE_DIR})
 include_directories(${CMAKE_CURRENT_SOURCE_DIR}/stub/runtime/)
 include_directories(${CMAKE_BINARY_DIR})
+include_directories(${CMAKE_BINARY_DIR}/proto/ge)
 include_directories(${CUDA_INCLUDE_DIRS})
 MESSAGE("check ut_test ${CMAKE_BINARY_DIR}")
 
@@ -103,6 +104,7 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
     "../../../mindspore/ccsrc/runtime/device/bucket.cc"
     "../../../mindspore/ccsrc/runtime/device/launch_kernel.cc"
     "../../../mindspore/ccsrc/runtime/device/ascend/profiling/*.cc"
+    "../../../mindspore/ccsrc/runtime/device/ascend/ge_runtime/*.cc"
    "../../../mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc"
    "../../../mindspore/ccsrc/runtime/device/ascend/ascend_launch_kernel.cc"
    "../../../mindspore/ccsrc/runtime/device/ascend/ascend_launch_mul.cc"
diff --git a/tests/ut/cpp/device/ge_runtime_test.cc b/tests/ut/cpp/device/ge_runtime_test.cc
new file mode 100644
index 00000000000..eeb47ca94dd
--- /dev/null
+++ b/tests/ut/cpp/device/ge_runtime_test.cc
@@ -0,0 +1,473 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <vector>
+#include "common/common_test.h"
+#define private public
+#include "runtime/device/ascend/ge_runtime/model_runner.h"
+#include "runtime/device/ascend/ge_runtime/runtime_model.h"
+#include "runtime/device/ascend/ge_runtime/task/task_factory.h"
+#include "runtime/device/ascend/ge_runtime/task/aicpu_task.h"
+#include "runtime/device/ascend/ge_runtime/task/event_record_task.h"
+#include "runtime/device/ascend/ge_runtime/task/event_wait_task.h"
+#include "runtime/device/ascend/ge_runtime/task/hccl_task.h"
+#include "runtime/device/ascend/ge_runtime/task/label_goto_task.h"
+#include "runtime/device/ascend/ge_runtime/task/label_manager.h"
+#include "runtime/device/ascend/ge_runtime/task/label_set_task.h"
+#include "runtime/device/ascend/ge_runtime/task/label_switch_task.h"
+#include "runtime/device/ascend/ge_runtime/task/memcpy_async_task.h"
+#include "runtime/device/ascend/ge_runtime/task/profiler_task.h"
+#include "runtime/device/ascend/ge_runtime/task/stream_active_task.h"
+#include "runtime/device/ascend/ge_runtime/task/stream_switch_task.h"
+#include "runtime/device/ascend/ge_runtime/task/tbe_task.h"
+#undef private
+#include "common/opskernel/ops_kernel_info_store.h"
+
+using namespace mindspore::ge::model_runner;
+using namespace testing;
+
+class MockOpsKernelInfoStore : public ge::OpsKernelInfoStore {
+ public:
+  ge::Status Initialize(const std::map<std::string, std::string> &) override { return ge::SUCCESS; }
+  ge::Status Finalize() override { return ge::SUCCESS; }
+  void GetAllOpsKernelInfo(std::map<std::string, ge::OpInfo> &infos) const override {}
+  bool CheckSupported(const ge::OpDescPtr &opDescPtr, std::string &un_supported_reason) const override { return true; }
+  ge::Status LoadTask(ge::GETaskInfo &task) override { return ge::SUCCESS; }
+};
+
+namespace mindspore {
+class TestAscendGeRuntime : public UT::Common {
+ public:
+  TestAscendGeRuntime() {}
+
+ private:
+  void TearDown() override {
+    {
+      std::lock_guard<std::mutex> lock(HcclTask::model_stream_mapping_mutex_);
+      HcclTask::model_stream_mapping_.clear();
+    }
+  }
+};
+
+TEST_F(TestAscendGeRuntime, test_task_create_null_task_info_failed) {
+  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
+                             {reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(1)},
+                             {reinterpret_cast<rtLabel_t>(1)}, {reinterpret_cast<rtEvent_t>(1)});
+  ASSERT_TRUE(TaskFactory::GetInstance().Create(model_context, nullptr) == nullptr);
+}
+
+TEST_F(TestAscendGeRuntime, test_aicpu_task_create_one_stream_success) {
+  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
+                             {reinterpret_cast<rtStream_t>(1)}, {reinterpret_cast<rtLabel_t>(1)},
+                             {reinterpret_cast<rtEvent_t>(1)});
+  std::shared_ptr<TaskInfo> aicpu_task_info = std::make_shared<AicpuTaskInfo>(
+    "op_name", 0, "so_name", "kernel_name", "node_def", "ext_info",
+    std::vector<void *>{reinterpret_cast<void *>(1)}, std::vector<void *>{reinterpret_cast<void *>(1)}, true);
+  std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, aicpu_task_info);
+  ASSERT_TRUE(std::dynamic_pointer_cast<AicpuTask>(task) != nullptr);
+  ASSERT_NO_THROW(task->Distribute());
+}
+
+TEST_F(TestAscendGeRuntime, test_aicpu_task_create_multi_stream_success) {
+  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
+                             {reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(1)},
+                             {reinterpret_cast<rtLabel_t>(1)}, {reinterpret_cast<rtEvent_t>(1)});
+  std::shared_ptr<TaskInfo> aicpu_task_info = std::make_shared<AicpuTaskInfo>(
+    "op_name", 0, "so_name", "kernel_name", "node_def", "",
+    std::vector<void *>{reinterpret_cast<void *>(1)}, std::vector<void *>{reinterpret_cast<void *>(1)}, true);
+  std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, aicpu_task_info);
+  ASSERT_TRUE(std::dynamic_pointer_cast<AicpuTask>(task) != nullptr);
+  ASSERT_NO_THROW(task->Distribute());
+}
+
+TEST_F(TestAscendGeRuntime, test_aicpu_task_create_invalid_stream_id_failed) {
+  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
+                             {reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(1)},
+                             {reinterpret_cast<rtLabel_t>(1)}, {reinterpret_cast<rtEvent_t>(1)});
+  std::shared_ptr<TaskInfo> aicpu_task_info = std::make_shared<AicpuTaskInfo>(
+    "op_name", 5, "so_name", "kernel_name", "node_def", "",
+    std::vector<void *>{reinterpret_cast<void *>(1)}, std::vector<void *>{reinterpret_cast<void *>(1)}, true);
+  ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(model_context, aicpu_task_info));
+}
+
+TEST_F(TestAscendGeRuntime, test_event_record_task_create_success) {
+  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
+                             {reinterpret_cast<rtStream_t>(1)}, {reinterpret_cast<rtLabel_t>(1)},
+                             {reinterpret_cast<rtEvent_t>(1)});
+  std::shared_ptr<TaskInfo> event_record_task_info = std::make_shared<EventRecordTaskInfo>("op_name", 0, 0);
+  std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, event_record_task_info);
+  ASSERT_TRUE(std::dynamic_pointer_cast<EventRecordTask>(task) != nullptr);
+  ASSERT_NO_THROW(task->Distribute());
+}
+
+TEST_F(TestAscendGeRuntime, test_event_record_task_create_invalid_event_id_failed) {
+  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
+                             {reinterpret_cast<rtStream_t>(1)}, {reinterpret_cast<rtLabel_t>(1)},
+                             {reinterpret_cast<rtEvent_t>(1)});
+  std::shared_ptr<TaskInfo> event_record_task_info = std::make_shared<EventRecordTaskInfo>("op_name", 0, 10);
+  ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(model_context, event_record_task_info));
+}
+
+TEST_F(TestAscendGeRuntime, test_event_wait_task_create_success) {
+  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
+                             {reinterpret_cast<rtStream_t>(1)}, {reinterpret_cast<rtLabel_t>(1)},
+                             {reinterpret_cast<rtEvent_t>(1)});
+  std::shared_ptr<TaskInfo> event_wait_task_info = std::make_shared<EventWaitTaskInfo>("op_name", 0, 0);
+  std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, event_wait_task_info);
+  ASSERT_TRUE(std::dynamic_pointer_cast<EventWaitTask>(task) != nullptr);
+  ASSERT_NO_THROW(task->Distribute());
+}
+
+TEST_F(TestAscendGeRuntime, test_event_wait_task_create_invalid_event_id_failed) {
+  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
+                             {reinterpret_cast<rtStream_t>(1)}, {reinterpret_cast<rtLabel_t>(1)},
+                             {reinterpret_cast<rtEvent_t>(1)});
+  std::shared_ptr<TaskInfo> event_wait_task_info = std::make_shared<EventWaitTaskInfo>("op_name", 0, 10);
+  ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(model_context, event_wait_task_info));
+}
+
+TEST_F(TestAscendGeRuntime, test_hccl_task_create_success) {
+  MockOpsKernelInfoStore ops_kernel_info_store;
+  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
+                             {reinterpret_cast<rtStream_t>(1)}, {reinterpret_cast<rtLabel_t>(1)},
+                             {reinterpret_cast<rtEvent_t>(1)});
+  std::shared_ptr<TaskInfo> hccl_task_info = std::make_shared<HcclTaskInfo>(
+    "op_name", 0, "hccl_type", reinterpret_cast<void *>(1), reinterpret_cast<void *>(2),
+    reinterpret_cast<void *>(3), 4, 5, std::vector<uint8_t>(6, 7),
+    reinterpret_cast<void *>(&ops_kernel_info_store), 9, 10, 11, 12, "group", true);
+  std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, hccl_task_info);
+  ASSERT_TRUE(std::dynamic_pointer_cast<HcclTask>(task) != nullptr);
+  ASSERT_NO_THROW(task->Distribute());
+}
+
+TEST_F(TestAscendGeRuntime, test_hccl_task_create_stream_reuse_success) {
+  const rtModel_t model = reinterpret_cast<rtModel_t>(0x12345678);
+  const rtStream_t stream = reinterpret_cast<rtStream_t>(0x87654321);
+  constexpr uint32_t stream_id = 0;
+  constexpr int64_t task1_stream_num = 3;
+  constexpr int64_t task2_stream_num = 5;
+  constexpr int64_t task3_stream_num = 4;
+  MockOpsKernelInfoStore ops_kernel_info_store;
+  ModelContext model_context(0, 0, 0, model, reinterpret_cast<rtStream_t>(2), {stream},
+                             {reinterpret_cast<rtLabel_t>(1)}, {reinterpret_cast<rtEvent_t>(1)});
+  std::shared_ptr<TaskInfo> hccl_task_info_1 = std::make_shared<HcclTaskInfo>(
+    "op_name", stream_id, "hccl_type", reinterpret_cast<void *>(1), reinterpret_cast<void *>(2),
+    reinterpret_cast<void *>(3), 4, task1_stream_num, std::vector<uint8_t>(6, 7),
+    reinterpret_cast<void *>(&ops_kernel_info_store), 9, 10, 11, 12, "group", true);
+  std::shared_ptr<TaskInfo> hccl_task_info_2 = std::make_shared<HcclTaskInfo>(
+    "op_name", stream_id, "hccl_type", reinterpret_cast<void *>(1), reinterpret_cast<void *>(2),
+    reinterpret_cast<void *>(3), 4, task2_stream_num, std::vector<uint8_t>(6, 7),
+    reinterpret_cast<void *>(&ops_kernel_info_store), 9, 10, 11, 12, "group", true);
+  std::shared_ptr<TaskInfo> hccl_task_info_3 = std::make_shared<HcclTaskInfo>(
+    "op_name", stream_id, "hccl_type", reinterpret_cast<void *>(1), reinterpret_cast<void *>(2),
+    reinterpret_cast<void *>(3), 4, task3_stream_num, std::vector<uint8_t>(6, 7),
+    reinterpret_cast<void *>(&ops_kernel_info_store), 9, 10, 11, 12, "group", true);
+  std::shared_ptr<Task> task_1 = TaskFactory::GetInstance().Create(model_context, hccl_task_info_1);
+  std::shared_ptr<Task> task_2 = TaskFactory::GetInstance().Create(model_context, hccl_task_info_2);
+  std::shared_ptr<Task> task_3 = TaskFactory::GetInstance().Create(model_context, hccl_task_info_3);
+  ASSERT_TRUE(std::dynamic_pointer_cast<HcclTask>(task_1) != nullptr);
+  ASSERT_TRUE(std::dynamic_pointer_cast<HcclTask>(task_2) != nullptr);
+  ASSERT_TRUE(std::dynamic_pointer_cast<HcclTask>(task_3) != nullptr);
+  ASSERT_NO_THROW(task_1->Distribute());
+  ASSERT_NO_THROW(task_2->Distribute());
+  ASSERT_NO_THROW(task_3->Distribute());
+  {
+    std::lock_guard<std::mutex> lock(HcclTask::model_stream_mapping_mutex_);
+    auto model_iter = HcclTask::model_stream_mapping_.find(model);
+    ASSERT_NE(model_iter, HcclTask::model_stream_mapping_.end());
+    auto stream_iter = model_iter->second.find(stream_id);
+    ASSERT_NE(stream_iter, model_iter->second.end());
+    const auto &stream_vec = stream_iter->second;
+    ASSERT_EQ(stream_vec.size(), std::max(task1_stream_num, std::max(task2_stream_num, task3_stream_num)));
+    for (const auto &s : stream_vec) {
+      auto shared = s.lock();
+      ASSERT_TRUE(shared != nullptr);
+    }
+  }
+}
+
+TEST_F(TestAscendGeRuntime, test_label_goto_task_create_success) {
+  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
+                             {reinterpret_cast<rtStream_t>(1)}, {reinterpret_cast<rtLabel_t>(1)},
+                             {reinterpret_cast<rtEvent_t>(1)});
+  std::shared_ptr<TaskInfo> label_goto_task_info = std::make_shared<LabelGotoTaskInfo>("op_name", 0, 0);
+  std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, label_goto_task_info);
+  auto label_goto_task = std::dynamic_pointer_cast<LabelGotoTask>(task);
+  ASSERT_TRUE(label_goto_task != nullptr);
+  ASSERT_NO_THROW(task->Distribute());
+  label_goto_task->index_value_ = new uint8_t[5];
+}
+
+TEST_F(TestAscendGeRuntime, test_label_goto_task_create_invalid_label_id_failed) {
+  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
+                             {reinterpret_cast<rtStream_t>(1)}, {reinterpret_cast<rtLabel_t>(1)},
+                             {reinterpret_cast<rtEvent_t>(1)});
+  std::shared_ptr<TaskInfo> label_goto_task_info = std::make_shared<LabelGotoTaskInfo>("op_name", 0, 1);
+  ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(model_context, label_goto_task_info));
+}
+
+TEST_F(TestAscendGeRuntime, test_label_goto_task_reuse_success) {
+  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
+                             {reinterpret_cast<rtStream_t>(1)}, {reinterpret_cast<rtLabel_t>(1)},
+                             {reinterpret_cast<rtEvent_t>(1)});
+  std::shared_ptr<TaskInfo> label_goto_task_info = std::make_shared<LabelGotoTaskInfo>("op_name", 0, 0);
+  std::shared_ptr<Task> task1 = TaskFactory::GetInstance().Create(model_context, label_goto_task_info);
+  std::shared_ptr<Task> task2 = TaskFactory::GetInstance().Create(model_context, label_goto_task_info);
+  auto label_goto_task_1 = std::dynamic_pointer_cast<LabelGotoTask>(task1);
+  auto label_goto_task_2 = std::dynamic_pointer_cast<LabelGotoTask>(task2);
+  ASSERT_TRUE(label_goto_task_1 != nullptr);
+  ASSERT_NO_THROW(task1->Distribute());
+  ASSERT_TRUE(label_goto_task_2 != nullptr);
+  ASSERT_NO_THROW(task2->Distribute());
+  ASSERT_EQ(label_goto_task_1->label_info_, label_goto_task_2->label_info_);
+}
+
+TEST_F(TestAscendGeRuntime, test_label_set_task_create_success) {
+  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
+                             {reinterpret_cast<rtStream_t>(1)}, {reinterpret_cast<rtLabel_t>(1)},
+                             {reinterpret_cast<rtEvent_t>(1)});
+  std::shared_ptr<TaskInfo> label_set_task_info = std::make_shared<LabelSetTaskInfo>("op_name", 0, 0);
+  std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, label_set_task_info);
+  ASSERT_TRUE(std::dynamic_pointer_cast<LabelSetTask>(task) != nullptr);
+  ASSERT_NO_THROW(task->Distribute());
+}
+
+TEST_F(TestAscendGeRuntime, test_label_set_task_create_invalid_label_id_failed) {
+  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
+                             {reinterpret_cast<rtStream_t>(1)}, {reinterpret_cast<rtLabel_t>(1)},
+                             {reinterpret_cast<rtEvent_t>(1)});
+  std::shared_ptr<TaskInfo> label_set_task_info = std::make_shared<LabelSetTaskInfo>("op_name", 0, 1);
+  ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(model_context, label_set_task_info));
+}
+
+TEST_F(TestAscendGeRuntime, test_label_switch_task_create_success) {
+  ModelContext model_context(
+    0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), {reinterpret_cast<rtStream_t>(1)},
+    {reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)}, {reinterpret_cast<rtEvent_t>(1)});
+  std::shared_ptr<TaskInfo> label_switch_task_info =
+    std::make_shared<LabelSwitchTaskInfo>("op_name", 0, 2, std::vector<uint32_t>{0, 1}, reinterpret_cast<void *>(1));
+  std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, label_switch_task_info);
+  ASSERT_TRUE(std::dynamic_pointer_cast<LabelSwitchTask>(task) != nullptr);
+  ASSERT_NO_THROW(task->Distribute());
+}
+
+TEST_F(TestAscendGeRuntime, test_label_switch_task_create_invalid_stream_id_failed) {
+  ModelContext model_context(
+    0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), {reinterpret_cast<rtStream_t>(1)},
+    {reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)}, {reinterpret_cast<rtEvent_t>(1)});
+  std::shared_ptr<TaskInfo> label_switch_task_info =
+    std::make_shared<LabelSwitchTaskInfo>("op_name", 1, 2, std::vector<uint32_t>{0, 1}, reinterpret_cast<void *>(1));
+  ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(model_context, label_switch_task_info));
+}
+
+TEST_F(TestAscendGeRuntime, test_label_switch_task_create_invalid_label_id_failed) {
+  ModelContext model_context(
+    0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), {reinterpret_cast<rtStream_t>(1)},
+    {reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)}, {reinterpret_cast<rtEvent_t>(1)});
+  std::shared_ptr<TaskInfo> label_switch_task_info = std::make_shared<LabelSwitchTaskInfo>(
+    "op_name", 0, 3, std::vector<uint32_t>{0, 1, 2}, reinterpret_cast<void *>(1));
+  ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(model_context, label_switch_task_info));
+}
+
+TEST_F(TestAscendGeRuntime, test_label_switch_task_reuse_success) {
+  ModelContext model_context(
+    0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), {reinterpret_cast<rtStream_t>(1)},
+    {reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)}, {reinterpret_cast<rtEvent_t>(1)});
+  std::shared_ptr<TaskInfo> label_switch_task_info =
+    std::make_shared<LabelSwitchTaskInfo>("op_name", 0, 2, std::vector<uint32_t>{0, 1}, reinterpret_cast<void *>(1));
+  std::shared_ptr<Task> task1 = TaskFactory::GetInstance().Create(model_context, label_switch_task_info);
+  std::shared_ptr<Task> task2 = TaskFactory::GetInstance().Create(model_context, label_switch_task_info);
+  auto label_switch_task_1 = std::dynamic_pointer_cast<LabelSwitchTask>(task1);
+  auto label_switch_task_2 = std::dynamic_pointer_cast<LabelSwitchTask>(task2);
+  ASSERT_TRUE(label_switch_task_1 != nullptr);
+  ASSERT_TRUE(label_switch_task_2 != nullptr);
+  ASSERT_NO_THROW(task1->Distribute());
+  ASSERT_NO_THROW(task2->Distribute());
+  ASSERT_EQ(label_switch_task_1->label_info_, label_switch_task_2->label_info_);
+}
+
+TEST_F(TestAscendGeRuntime, test_memcpy_async_task_create_success) {
+  ModelContext model_context(
+    0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), {reinterpret_cast<rtStream_t>(1)},
+    {reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)}, {reinterpret_cast<rtEvent_t>(1)});
+  std::shared_ptr<TaskInfo> memcpy_task_info = std::make_shared<MemcpyAsyncTaskInfo>(
+    "op_name", 0, reinterpret_cast<void *>(1), 2, reinterpret_cast<void *>(3), 4, 5, true);
+  std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, memcpy_task_info);
+  ASSERT_TRUE(std::dynamic_pointer_cast<MemcpyAsyncTask>(task) != nullptr);
+  ASSERT_NO_THROW(task->Distribute());
+}
+
+TEST_F(TestAscendGeRuntime, test_memcpy_async_task_create_invalid_stream_id_failed) {
+  ModelContext model_context(
+    0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), {reinterpret_cast<rtStream_t>(1)},
+    {reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)}, {reinterpret_cast<rtEvent_t>(1)});
+  std::shared_ptr<TaskInfo> memcpy_task_info = std::make_shared<MemcpyAsyncTaskInfo>(
+    "op_name", 1, reinterpret_cast<void *>(1), 2, reinterpret_cast<void *>(3), 4, 5, true);
+  ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(model_context, memcpy_task_info));
+}
+
+TEST_F(TestAscendGeRuntime, test_profiler_task_create_success) {
+  ModelContext model_context(
+    0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), {reinterpret_cast<rtStream_t>(1)},
+    {reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)}, {reinterpret_cast<rtEvent_t>(1)});
+  std::shared_ptr<TaskInfo> profiler_task_info = std::make_shared<ProfilerTraceTaskInfo>("op_name", 0, 1, true, 2);
+  std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, profiler_task_info);
+  ASSERT_TRUE(std::dynamic_pointer_cast<ProfilerTask>(task) != nullptr);
+  ASSERT_NO_THROW(task->Distribute());
+}
+
+TEST_F(TestAscendGeRuntime, test_profiler_task_create_invalid_stream_id_failed) {
+  ModelContext model_context(
+    0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), {reinterpret_cast<rtStream_t>(1)},
+    {reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)}, {reinterpret_cast<rtEvent_t>(1)});
+  std::shared_ptr<TaskInfo> profiler_task_info = std::make_shared<ProfilerTraceTaskInfo>("op_name", 1, 1, true, 2);
+  ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(model_context, profiler_task_info));
+}
+
+TEST_F(TestAscendGeRuntime, test_stream_active_task_create_success) {
+  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
+                             {reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(2)},
+                             {reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)},
+                             {reinterpret_cast<rtEvent_t>(1)});
+  std::shared_ptr<TaskInfo> stream_active_task_info = std::make_shared<StreamActiveTaskInfo>("op_name", 0, 1);
+  std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, stream_active_task_info);
+  ASSERT_TRUE(std::dynamic_pointer_cast<StreamActiveTask>(task) != nullptr);
+  ASSERT_NO_THROW(task->Distribute());
+}
+
+TEST_F(TestAscendGeRuntime, test_stream_active_task_create_invalid_active_stream_id_failed) {
+  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
+                             {reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(2)},
+                             {reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)},
+                             {reinterpret_cast<rtEvent_t>(1)});
+  std::shared_ptr<TaskInfo> stream_active_task_info = std::make_shared<StreamActiveTaskInfo>("op_name", 0, 2);
+  ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(model_context, stream_active_task_info));
+}
+
+TEST_F(TestAscendGeRuntime, test_stream_switch_task_create_success) {
+  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
+                             {reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(2)},
+                             {reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)},
+                             {reinterpret_cast<rtEvent_t>(1)});
+  std::shared_ptr<TaskInfo> stream_switch_task_info = std::make_shared<StreamSwitchTaskInfo>(
+    "op_name", 0, 1, reinterpret_cast<void *>(2), reinterpret_cast<void *>(3), 4, 5);
+  std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, stream_switch_task_info);
+  ASSERT_TRUE(std::dynamic_pointer_cast<StreamSwitchTask>(task) != nullptr);
+  ASSERT_NO_THROW(task->Distribute());
+}
+
+TEST_F(TestAscendGeRuntime, test_stream_switch_task_create_invalid_true_stream_id_failed) {
+  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
+                             {reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(2)},
+                             {reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)},
+                             {reinterpret_cast<rtEvent_t>(1)});
+  std::shared_ptr<TaskInfo> stream_switch_task_info = std::make_shared<StreamSwitchTaskInfo>(
+    "op_name", 0, 2, reinterpret_cast<void *>(2), reinterpret_cast<void *>(3), 4, 5);
+  std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, stream_switch_task_info);
+  ASSERT_TRUE(std::dynamic_pointer_cast<StreamSwitchTask>(task) != nullptr);
+  ASSERT_ANY_THROW(task->Distribute());
+}
+
+TEST_F(TestAscendGeRuntime, test_stream_switch_task_create_invalid_stream_id_failed) {
+  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
+                             {reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(2)},
+                             {reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)},
+                             {reinterpret_cast<rtEvent_t>(1)});
+  std::shared_ptr<TaskInfo> stream_switch_task_info = std::make_shared<StreamSwitchTaskInfo>(
+    "op_name", 2, 1, reinterpret_cast<void *>(2), reinterpret_cast<void *>(3), 4, 5);
+  ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(model_context, stream_switch_task_info));
+}
+
+TEST_F(TestAscendGeRuntime, test_tbe_task_create_success) {
+  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
+                             {reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(2)},
+                             {reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)},
+                             {reinterpret_cast<rtEvent_t>(1)});
+  std::shared_ptr<TaskInfo> tbe_task_info = std::make_shared<TbeTaskInfo>(
+    "op_name", 0, "stub_func", 1, std::vector<uint8_t>(100, 2), 100, std::vector<uint8_t>{5, 6},
+    reinterpret_cast<void *>(7), 8, std::vector<uint8_t>{9, 10},
+    std::vector<void *>{reinterpret_cast<void *>(11), reinterpret_cast<void *>(12)},
+    std::vector<void *>{reinterpret_cast<void *>(13), reinterpret_cast<void *>(14)},
+    std::vector<void *>{reinterpret_cast<void *>(15), reinterpret_cast<void *>(16)}, true);
+  std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, tbe_task_info);
+  auto tbe_task = std::dynamic_pointer_cast<TbeTask>(task);
+  ASSERT_TRUE(tbe_task != nullptr);
+  ASSERT_NO_THROW(task->Distribute());
+  tbe_task->args_ = new uint8_t[5];
+}
+
+TEST_F(TestAscendGeRuntime, test_tbe_task_create_invalid_stream_id_failed) {
+  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
+                             {reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(2)},
+                             {reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)},
+                             {reinterpret_cast<rtEvent_t>(1)});
+  std::shared_ptr<TaskInfo> tbe_task_info = std::make_shared<TbeTaskInfo>(
+    "op_name", 3, "stub_func", 1, std::vector<uint8_t>(100, 2), 100, std::vector<uint8_t>{5, 6},
+    reinterpret_cast<void *>(7), 8, std::vector<uint8_t>{9, 10},
+    std::vector<void *>{reinterpret_cast<void *>(11), reinterpret_cast<void *>(12)},
+    std::vector<void *>{reinterpret_cast<void *>(13), reinterpret_cast<void *>(14)},
+    std::vector<void *>{reinterpret_cast<void *>(15), reinterpret_cast<void *>(16)}, true);
+  ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(model_context, tbe_task_info));
+}
+
+TEST_F(TestAscendGeRuntime, test_tbe_task_create_empty_stub_func_failed) {
+  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
+                             {reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(2)},
+                             {reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)},
+                             {reinterpret_cast<rtEvent_t>(1)});
+  std::shared_ptr<TaskInfo> tbe_task_info = std::make_shared<TbeTaskInfo>(
+    "op_name", 0, "", 1, std::vector<uint8_t>(100, 2), 100, std::vector<uint8_t>{5, 6},
+    reinterpret_cast<void *>(7), 8, std::vector<uint8_t>{9, 10},
+    std::vector<void *>{reinterpret_cast<void *>(11), reinterpret_cast<void *>(12)},
+    std::vector<void *>{reinterpret_cast<void *>(13), reinterpret_cast<void *>(14)},
+    std::vector<void *>{reinterpret_cast<void *>(15), reinterpret_cast<void *>(16)}, true);
+  std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, tbe_task_info);
+  ASSERT_TRUE(std::dynamic_pointer_cast<TbeTask>(task) != nullptr);
+  ASSERT_ANY_THROW(task->Distribute());
+}
+
+TEST_F(TestAscendGeRuntime, test_model_runner_success) {
+  constexpr uint32_t model_id = 0;
+  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
+                             {reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(2)},
+                             {reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)},
+                             {reinterpret_cast<rtEvent_t>(1)});
+  std::shared_ptr<TaskInfo> tbe_task_info = std::make_shared<TbeTaskInfo>(
+    "op_name", 0, "stub_func", 1, std::vector<uint8_t>(100, 2), 100, std::vector<uint8_t>{5, 6},
+    reinterpret_cast<void *>(7), 8, std::vector<uint8_t>{9, 10},
+    std::vector<void *>{reinterpret_cast<void *>(11), reinterpret_cast<void *>(12)},
+    std::vector<void *>{reinterpret_cast<void *>(13), reinterpret_cast<void *>(14)},
+    std::vector<void *>{reinterpret_cast<void *>(15), reinterpret_cast<void *>(16)}, true);
+  std::shared_ptr<TaskInfo> aicpu_task_info = std::make_shared<AicpuTaskInfo>(
+    "op_name", 0, "so_name", "kernel_name", "node_def", "ext_info",
+    std::vector<void *>{reinterpret_cast<void *>(1)}, std::vector<void *>{reinterpret_cast<void *>(1)}, true);
+  auto davinci_model =
+    std::make_shared<DavinciModel>(std::vector<std::shared_ptr<TaskInfo>>{tbe_task_info, aicpu_task_info},
+                                   std::vector<uint32_t>{}, std::vector<uint32_t>{}, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0);
+  ASSERT_NO_THROW(ModelRunner::Instance().LoadDavinciModel(0, 0, model_id, davinci_model));
+  auto iter = ModelRunner::Instance().runtime_models_.find(model_id);
+  ASSERT_TRUE(iter != ModelRunner::Instance().runtime_models_.end());
+  auto &task_list = iter->second->task_list_;
+  task_list.clear();
+  ASSERT_NO_THROW(task_list.emplace_back(TaskFactory::GetInstance().Create(model_context, tbe_task_info)));
+  ASSERT_NO_THROW(task_list.emplace_back(TaskFactory::GetInstance().Create(model_context, aicpu_task_info)));
+  ASSERT_NO_THROW(ModelRunner::Instance().DistributeTask(model_id));
+  ASSERT_NO_THROW(ModelRunner::Instance().LoadModelComplete(model_id));
+  ASSERT_NO_THROW(ModelRunner::Instance().RunModel(model_id));
+  ASSERT_FALSE(ModelRunner::Instance().GetTaskIdList(model_id).empty());
+  ASSERT_FALSE(ModelRunner::Instance().GetStreamIdList(model_id).empty());
+  ASSERT_FALSE(ModelRunner::Instance().GetRuntimeInfoMap(model_id).empty());
+  ASSERT_NO_THROW(ModelRunner::Instance().GetModelHandle(model_id));
+  ASSERT_NO_THROW(ModelRunner::Instance().UnloadModel(model_id));
+}
+}  // namespace mindspore
+TEST_F(TestAscendGeRuntime, test_model_runner_success) {
+  constexpr uint32_t model_id = 0;
+  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
+                             {reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(2)},
+                             {reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)},
+                             {reinterpret_cast<rtEvent_t>(1)});
+  std::shared_ptr<TbeTaskInfo> tbe_task_info = std::make_shared<TbeTaskInfo>(
+    "op_name", 0, "stub_func", 1, std::vector<uint8_t>(100, 2), 100, std::vector<uint8_t>{5, 6},
+    reinterpret_cast<void *>(7), 8, std::vector<uint8_t>{9, 10},
+    std::vector<void *>{reinterpret_cast<void *>(11), reinterpret_cast<void *>(12)},
+    std::vector<void *>{reinterpret_cast<void *>(13), reinterpret_cast<void *>(14)},
+    std::vector<void *>{reinterpret_cast<void *>(15), reinterpret_cast<void *>(16)}, true);
+  std::shared_ptr<AicpuTaskInfo> aicpu_task_info = std::make_shared<AicpuTaskInfo>(
+    "op_name", 0, "so_name", "kernel_name", "node_def", "ext_info", std::vector<void *>{reinterpret_cast<void *>(1)},
+    std::vector<void *>{reinterpret_cast<void *>(1)}, true);
+  auto davinci_model =
+    std::make_shared<DavinciModel>(std::vector<std::shared_ptr<TaskInfo>>{tbe_task_info, aicpu_task_info},
+                                   std::vector<uint32_t>{}, std::vector<uint32_t>{}, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0);
+  ASSERT_NO_THROW(ModelRunner::Instance().LoadDavinciModel(0, 0, model_id, davinci_model));
+  auto iter = ModelRunner::Instance().runtime_models_.find(model_id);
+  ASSERT_TRUE(iter != ModelRunner::Instance().runtime_models_.end());
+  auto &task_list = iter->second->task_list_;
+  task_list.clear();
+  ASSERT_NO_THROW(task_list.emplace_back(TaskFactory::GetInstance().Create(model_context, tbe_task_info)));
+  ASSERT_NO_THROW(task_list.emplace_back(TaskFactory::GetInstance().Create(model_context, aicpu_task_info)));
+  ASSERT_NO_THROW(ModelRunner::Instance().DistributeTask(model_id));
+  ASSERT_NO_THROW(ModelRunner::Instance().LoadModelComplete(model_id));
+  ASSERT_NO_THROW(ModelRunner::Instance().RunModel(model_id));
+  ASSERT_FALSE(ModelRunner::Instance().GetTaskIdList(model_id).empty());
+  ASSERT_FALSE(ModelRunner::Instance().GetStreamIdList(model_id).empty());
+  ASSERT_FALSE(ModelRunner::Instance().GetRuntimeInfoMap(model_id).empty());
+  ASSERT_NO_THROW(ModelRunner::Instance().GetModelHandle(model_id));
+  ASSERT_NO_THROW(ModelRunner::Instance().UnloadModel(model_id));
+}
+}  // namespace mindspore
diff --git a/tests/ut/cpp/stub/ge/ge_task_launch_stub.cc b/tests/ut/cpp/stub/ge/ge_task_launch_stub.cc
index 24fbde1fc94..18bd0929c0f 100644
--- a/tests/ut/cpp/stub/ge/ge_task_launch_stub.cc
+++ b/tests/ut/cpp/stub/ge/ge_task_launch_stub.cc
@@ -14,51 +14,8 @@
  * limitations under the License.
  */
 #include <vector>
-#include "framework/ge_runtime/model_runner.h"
 #include "runtime/hccl_adapter/hccl_adapter.h"
 
-namespace ge {
-namespace model_runner {
-ModelRunner &ModelRunner::Instance() {
-  static ModelRunner runner;
-  return runner;
-}
-
-bool ModelRunner::LoadDavinciModel(uint32_t device_id, uint64_t session_id, uint32_t model_id,
-                                   std::shared_ptr<DavinciModel> ascend_model,
-                                   std::shared_ptr<ge::ModelListener> listener) {
-  return true;
-}
-
-bool ModelRunner::UnloadModel(uint32_t model_id) { return true; }
-
-bool ModelRunner::LoadModelComplete(uint32_t model_id) { return true; }
-
-bool ModelRunner::RunModel(uint32_t model_id, const ge::InputData &input_data, ge::OutputData *output_data) {
-  return true;
-}
-
-void *ModelRunner::GetModelHandle(uint32_t model_id) const { return nullptr; }
-
-bool ModelRunner::DistributeTask(uint32_t model_id) { return true; }
-
-const std::vector<uint32_t> &ModelRunner::GetTaskIdList(uint32_t model_id) const {
-  static std::vector<uint32_t> task_id_list;
-  return task_id_list;
-}
-
-const std::vector<uint32_t> &ModelRunner::GetStreamIdList(uint32_t model_id) const {
-  static std::vector<uint32_t> stream_id_list;
-  return stream_id_list;
-}
-
-const std::map<std::string, std::shared_ptr<RuntimeInfo>> &ModelRunner::GetRuntimeInfoMap(uint32_t model_id) const {
-  static std::map<std::string, std::shared_ptr<RuntimeInfo>> runtime_info_map;
-  return runtime_info_map;
-}
-}  // namespace model_runner
-}  // namespace ge
-
 namespace mindspore {
 namespace hccl {
 bool InitHccl(uint32_t, std::string_view, std::string_view) { return true; }
diff --git a/tests/ut/cpp/stub/runtime/runtime_stub.cc b/tests/ut/cpp/stub/runtime/runtime_stub.cc
index d7a9876620c..8e7d749e58a 100644
--- a/tests/ut/cpp/stub/runtime/runtime_stub.cc
+++ b/tests/ut/cpp/stub/runtime/runtime_stub.cc
@@ -141,9 +141,9 @@ rtError_t rtGetFunctionByName(const char *stubName, void **stubFunc) { return RT
 
 rtError_t rtSetTaskGenCallback(rtTaskGenCallback callback) { return RT_ERROR_NONE; }
 
-RTS_API rtError_t rtProfilerStart(uint64_t profConfig, int32_t numsDev, uint32_t* deviceList) { return RT_ERROR_NONE; }
+RTS_API rtError_t rtProfilerStart(uint64_t profConfig, int32_t numsDev, uint32_t *deviceList) { return RT_ERROR_NONE; }
 
-RTS_API rtError_t rtProfilerStop(uint64_t profConfig, int32_t numsDev, uint32_t* deviceList) { return RT_ERROR_NONE; }
+RTS_API rtError_t rtProfilerStop(uint64_t profConfig, int32_t numsDev, uint32_t *deviceList) { return RT_ERROR_NONE; }
 
 int AdxDataDumpServerInit() { return 0; }
 
@@ -151,11 +151,13 @@ int AdxDataDumpServerUnInit() { return 0; }
 
 RTS_API rtError_t rtGetTaskIdAndStreamID(uint32_t *taskid, uint32_t *streamid) { return RT_ERROR_NONE; }
 
-RTS_API rtError_t rtSetTaskFailCallback(rtTaskFailCallback callback) {return RT_ERROR_NONE; }
+RTS_API rtError_t rtSetTaskFailCallback(rtTaskFailCallback callback) { return RT_ERROR_NONE; }
 
-RTS_API rtError_t rtRegDeviceStateCallback(const char *regName, rtDeviceStateCallback callback) {return RT_ERROR_NONE; }
+RTS_API rtError_t rtRegDeviceStateCallback(const char *regName, rtDeviceStateCallback callback) {
+  return RT_ERROR_NONE;
+}
 
-RTS_API rtError_t rtSetMsprofReporterCallback(MsprofReporterCallback callback) {return RT_ERROR_NONE; }
+RTS_API rtError_t rtSetMsprofReporterCallback(MsprofReporterCallback callback) { return RT_ERROR_NONE; }
 
 RTS_API rtError_t rtRegTaskFailCallbackByModule(const char *moduleName, rtTaskFailCallback callback) {
   return RT_ERROR_NONE;
@@ -168,3 +170,28 @@ RTS_API rtError_t rtDevBinaryUnRegister(void *handle) { return RT_ERROR_NONE; }
 RTS_API rtError_t rtMemsetAsync(void *ptr, uint64_t destMax, uint32_t value, uint64_t count, rtStream_t stream) {
   return RT_ERROR_NONE;
 }
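+
+// No-op stubs for the additional rts APIs that the internal ge_runtime task
+// implementations call; each one simply reports success.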
+RTS_API rtError_t rtLabelListCpy(rtLabel_t *label, uint32_t labelNumber, void *dst, uint32_t dstMax) {
+  return RT_ERROR_NONE;
+}
+
+RTS_API rtError_t rtModelGetTaskId(rtModel_t model, uint32_t *taskid, uint32_t *streamid) { return RT_ERROR_NONE; }
+
+RTS_API rtError_t rtLabelCreateEx(rtLabel_t *label, rtStream_t stream) { return RT_ERROR_NONE; }
+
+RTS_API rtError_t rtCpuKernelLaunchWithFlag(const void *soName, const void *kernelName, uint32_t blockDim,
+                                            const void *args, uint32_t argsSize, rtSmDesc_t *smDesc, rtStream_t stream,
+                                            uint32_t flags) {
+  return RT_ERROR_NONE;
+}
+
+RTS_API rtError_t rtLabelSwitchByIndex(void *ptr, uint32_t max, void *labelInfoPtr, rtStream_t stream) {
+  return RT_ERROR_NONE;
+}
+
+RTS_API rtError_t rtProfilerTrace(uint64_t id, bool notify, uint32_t flags, rtStream_t stream) { return RT_ERROR_NONE; }
+
+RTS_API rtError_t rtKernelLaunchWithFlag(const void *stubFunc, uint32_t blockDim, void *args, uint32_t argsSize,
+                                         rtSmDesc_t *smDesc, rtStream_t stream, uint32_t flags) {
+  return RT_ERROR_NONE;
+}