forked from mindspore-Ecosystem/mindspore
!15107 use internal ge runtime
From: @zhoufeng54
Reviewed-by:
Signed-off-by:
commit 966f89198a
@@ -252,7 +252,6 @@ if(NOT ENABLE_GE)
     FILES
         ${CMAKE_BINARY_DIR}/graphengine/metadef/graph/libgraph.so
        ${CMAKE_BINARY_DIR}/graphengine/ge/common/libge_common.so
-        ${CMAKE_BINARY_DIR}/graphengine/ge/ge_runtime/libge_runtime.so
     DESTINATION ${INSTALL_LIB_DIR}
     COMPONENT mindspore
 )
@@ -313,7 +313,7 @@ if(ENABLE_D)
     target_link_options(ms_profile PRIVATE -Wl,-init,common_log_init)
     target_link_libraries(ms_profile -Wl,--start-group -Wl,--whole-archive ${PROFILING} -Wl,--no-whole-archive
                           mindspore::protobuf -Wl,--end-group)
-    target_link_libraries(mindspore ge_runtime ${CCE_LIB} ${RUNTIME_LIB} ${TSDCLIENT} ${HCCL} ${DATATRANSFER}
+    target_link_libraries(mindspore ${CCE_LIB} ${RUNTIME_LIB} ${TSDCLIENT} ${HCCL} ${DATATRANSFER}
                           ${HCCL_ADPTER} ${REGISTER} -Wl,--no-as-needed ${OPTILING} ${HCCL_BUILDER}
                           ${HCCL_RA} ${PLATFORM} ${ACL})
     target_link_libraries(mindspore -Wl,--start-group proto_input mindspore::protobuf -Wl,--end-group)
@@ -30,7 +30,7 @@
 #include "runtime/device/kernel_runtime.h"
 #include "runtime/device/ascend/executor/host_dynamic_kernel.h"

-using AicpuTaskInfoPtr = std::shared_ptr<ge::model_runner::AicpuTaskInfo>;
+using AicpuTaskInfoPtr = std::shared_ptr<mindspore::ge::model_runner::AicpuTaskInfo>;
 using AicpuDynamicKernel = mindspore::device::ascend::AiCpuDynamicKernel;
 using HostDynamicKernel = mindspore::device::ascend::HostDynamicKernel;

@@ -193,9 +193,9 @@ std::vector<TaskInfoPtr> AicpuOpKernelMod::GenTask(const std::vector<AddressPtr>
     node_name_ = kPack;
   }

-  AicpuTaskInfoPtr task_info_ptr =
-    make_shared<ge::model_runner::AicpuTaskInfo>(kernel_name_, stream_id, node_so_, node_name_, node_def_str_,
-                                                 ext_info_, input_data_addrs, output_data_addrs, NeedDump());
+  AicpuTaskInfoPtr task_info_ptr = std::make_shared<mindspore::ge::model_runner::AicpuTaskInfo>(
+    kernel_name_, stream_id, node_so_, node_name_, node_def_str_, ext_info_, input_data_addrs, output_data_addrs,
+    NeedDump());

   MS_LOG(INFO) << "AicpuOpKernelMod GenTask end";
   return {task_info_ptr};
@@ -29,7 +29,7 @@ using std::fstream;
 using std::map;
 using std::mutex;
 using std::string;
-using TbeTaskInfoPtr = std::shared_ptr<ge::model_runner::TbeTaskInfo>;
+using TbeTaskInfoPtr = std::shared_ptr<mindspore::ge::model_runner::TbeTaskInfo>;
 using tbe::KernelManager;
 constexpr uint32_t DEFAULT_BLOCK_DIM = 1;
 /**
@@ -118,7 +118,7 @@ std::vector<TaskInfoPtr> AkgKernelMod::GenTask(const std::vector<AddressPtr> &in

   MS_LOG(DEBUG) << "The block_dim is:" << block_dim;

-  TbeTaskInfoPtr task_info_ptr = make_shared<ge::model_runner::TbeTaskInfo>(
+  TbeTaskInfoPtr task_info_ptr = std::make_shared<mindspore::ge::model_runner::TbeTaskInfo>(
     kernel_name_, stream_id, stub_func, block_dim, args, args_size, sm_desc, binary, binary_size, meta_data,
     input_data_addrs, output_data_addrs, workspace_addrs, NeedDump());
   return {task_info_ptr};
@@ -19,11 +19,11 @@

 #include <vector>
 #include <memory>
-#include "framework/ge_runtime/task_info.h"
+#include "runtime/device/ascend/ge_runtime/task_info.h"
 #include "backend/kernel_compiler/kernel.h"
 #include "debug/data_dump/dump_json_parser.h"

-using TaskInfoPtr = std::shared_ptr<ge::model_runner::TaskInfo>;
+using TaskInfoPtr = std::shared_ptr<mindspore::ge::model_runner::TaskInfo>;
 namespace mindspore {
 namespace kernel {
 class AscendKernelMod : public KernelMod {
@@ -24,8 +24,8 @@
 #include "runtime/device/ascend/executor/hccl_dynamic_kernel.h"
 #include "runtime/hccl_adapter/hccl_adapter.h"

-using HcclTaskInfoPtr = std::shared_ptr<ge::model_runner::HcclTaskInfo>;
-using ge::model_runner::HcclTaskInfo;
+using HcclTaskInfoPtr = std::shared_ptr<mindspore::ge::model_runner::HcclTaskInfo>;
+using mindspore::ge::model_runner::HcclTaskInfo;

 namespace {
 static std::map<std::string, std::string> kMsOpNameToHcomHcclType = {
@@ -18,7 +18,7 @@
 #include <memory>
 #include "runtime/mem.h"

-using ge::model_runner::MemcpyAsyncTaskInfo;
+using mindspore::ge::model_runner::MemcpyAsyncTaskInfo;
 using MemcpyAsyncTaskInfoPtr = std::shared_ptr<MemcpyAsyncTaskInfo>;

 namespace mindspore {
@@ -20,7 +20,7 @@
 #include "framework/ge_runtime/task_info.h"
 #include "backend/session/anf_runtime_algorithm.h"

-using ge::model_runner::LabelGotoTaskInfo;
+using mindspore::ge::model_runner::LabelGotoTaskInfo;
 using LabelGotoTaskInfoPtr = std::shared_ptr<LabelGotoTaskInfo>;

 namespace mindspore {
@@ -20,7 +20,7 @@
 #include "framework/ge_runtime/task_info.h"
 #include "backend/session/anf_runtime_algorithm.h"

-using ge::model_runner::LabelSetTaskInfo;
+using mindspore::ge::model_runner::LabelSetTaskInfo;
 using LabelSetTaskInfoPtr = std::shared_ptr<LabelSetTaskInfo>;

 namespace mindspore {
@@ -21,7 +21,7 @@
 #include "framework/ge_runtime/task_info.h"
 #include "backend/session/anf_runtime_algorithm.h"

-using ge::model_runner::LabelSwitchTaskInfo;
+using mindspore::ge::model_runner::LabelSwitchTaskInfo;
 using LabelSwitchTaskInfoPtr = std::shared_ptr<LabelSwitchTaskInfo>;

 namespace mindspore {
@@ -25,7 +25,7 @@
 #include "runtime/device/kernel_runtime.h"
 #include "runtime/device/ascend/executor/rts/memcpy_rts_dynamic_kernel.h"

-using ge::model_runner::MemcpyAsyncTaskInfo;
+using mindspore::ge::model_runner::MemcpyAsyncTaskInfo;
 using MemcpyAsyncTaskInfoPtr = std::shared_ptr<MemcpyAsyncTaskInfo>;
 using AddressPtrList = std::vector<mindspore::kernel::AddressPtr>;
 using mindspore::device::ascend::MemcpyRtsDynamicKernel;
@@ -23,7 +23,7 @@
 #include "backend/session/anf_runtime_algorithm.h"
 #include "runtime/device/ascend/executor/rts/profiling_rts_dynamic_kernel.h"

-using ProfilerTraceTaskInfo = ge::model_runner::ProfilerTraceTaskInfo;
+using ProfilerTraceTaskInfo = mindspore::ge::model_runner::ProfilerTraceTaskInfo;
 using mindspore::device::ascend::ProfilingRtsDynamicKernel;
 using mindspore::device::ascend::ProfilingUtils;

@@ -23,7 +23,7 @@

 namespace mindspore {
 namespace kernel {
-using ge::model_runner::EventWaitTaskInfo;
+using mindspore::ge::model_runner::EventWaitTaskInfo;
 using EventWaitTaskInfoPtr = std::shared_ptr<EventWaitTaskInfo>;

 RecvKernel::RecvKernel() { event_id_ = 0; }
@@ -20,7 +20,7 @@
 #include "framework/ge_runtime/task_info.h"
 #include "backend/session/anf_runtime_algorithm.h"

-using ge::model_runner::EventRecordTaskInfo;
+using mindspore::ge::model_runner::EventRecordTaskInfo;
 using EventRecordTaskInfoPtr = std::shared_ptr<EventRecordTaskInfo>;

 namespace mindspore {
@@ -20,7 +20,7 @@
 #include "framework/ge_runtime/task_info.h"
 #include "backend/session/anf_runtime_algorithm.h"

-using ge::model_runner::StreamActiveTaskInfo;
+using mindspore::ge::model_runner::StreamActiveTaskInfo;
 using StreamActiveTaskInfoPtr = std::shared_ptr<StreamActiveTaskInfo>;

 namespace mindspore {
@@ -21,7 +21,7 @@
 #include "framework/ge_runtime/task_info.h"
 #include "backend/session/anf_runtime_algorithm.h"

-using ge::model_runner::StreamSwitchTaskInfo;
+using mindspore::ge::model_runner::StreamSwitchTaskInfo;
 using StreamSwitchTaskInfoPtr = std::shared_ptr<StreamSwitchTaskInfo>;

 namespace mindspore {
@@ -24,7 +24,7 @@

 namespace mindspore {
 namespace kernel {
-using TbeTaskInfoPtr = std::shared_ptr<ge::model_runner::TbeTaskInfo>;
+using TbeTaskInfoPtr = std::shared_ptr<mindspore::ge::model_runner::TbeTaskInfo>;
 using tbe::KernelManager;
 using AddressPtrList = std::vector<mindspore::kernel::AddressPtr>;
 bool TbeKernelMod::Launch(const std::vector<mindspore::kernel::AddressPtr> &inputs,
@@ -102,7 +102,7 @@ std::vector<TaskInfoPtr> TbeKernelMod::GenTask(const std::vector<AddressPtr> &in

   MS_LOG(INFO) << "block_dim is:" << block_dim_;

-  TbeTaskInfoPtr task_info_ptr = make_shared<ge::model_runner::TbeTaskInfo>(
+  TbeTaskInfoPtr task_info_ptr = std::make_shared<mindspore::ge::model_runner::TbeTaskInfo>(
     kernel_name_, stream_id, stub_func, block_dim_, args, 0, sm_desc, nullptr, 0, meta_data, input_data_addrs,
     output_data_addrs, workspace_addrs, NeedDump());
   return {task_info_ptr};
@@ -36,7 +36,7 @@ using mindspore::kernel::tbe::TbeUtils;
 bool TbeOpParallelBuild(const std::vector<AnfNodePtr> &anf_nodes) {
   auto build_manger = std::make_shared<ParallelBuildManager>();
   MS_EXCEPTION_IF_NULL(build_manger);
-  static set<std::string> processed_kernel;
+  static std::set<std::string> processed_kernel;
   auto context_ptr = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context_ptr);
   auto tune_mode = context_ptr->get_param<std::string>(MS_CTX_TUNE_MODE);
@@ -259,8 +259,8 @@ bool ParallelBuildManager::SearchInCache(const std::string &json_name, const std
 }

 KernelModPtr ParallelBuildManager::GenKernelMod(const string &json_name, const string &processor,
-                                                const vector<size_t> &input_size_list,
-                                                const vector<size_t> &output_size_list,
+                                                const std::vector<size_t> &input_size_list,
+                                                const std::vector<size_t> &output_size_list,
                                                 const mindspore::kernel::KernelPackPtr &kernel_pack) const {
   MS_EXCEPTION_IF_NULL(kernel_pack);
   auto kernel_json_info = kernel_pack->kernel_json_info();
@@ -27,6 +27,7 @@
 #include "proto/tensor_shape.pb.h"
 #include "proto/attr.pb.h"
 #include "proto/node_def.pb.h"
+#include "runtime/rt.h"

 using mindspore::kernel::Address;
 using AddressPtr = std::shared_ptr<Address>;
@@ -24,6 +24,7 @@
 #include "ps/ps_cache/ps_cache_basic.h"
 #include "backend/kernel_compiler/aicpu/aicpu_kernel_mod.h"
 #include "ir/dtype.h"
+#include "runtime/base.h"

 namespace mindspore {
 namespace ps {
@@ -79,3 +79,9 @@ list(REMOVE_ITEM D_SRC_LIST "ascend/profiling/profiling_callback_register.cc")
 set_property(SOURCE ${DEVICE_SRC_LIST} ${D_SRC_LIST} ${CPU_SRC_LIST}
     PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE)
 add_library(_mindspore_runtime_device_obj OBJECT ${DEVICE_SRC_LIST} ${D_SRC_LIST} ${CPU_SRC_LIST} ${TDT_SRC_LIST})
+if(ENABLE_D)
+    file(GLOB_RECURSE GE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "ascend/ge_runtime/*.cc")
+    set_property(SOURCE ${GE_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_GE)
+    target_include_directories(_mindspore_runtime_device_obj PRIVATE ${CMAKE_BINARY_DIR}/proto/ge)
+    add_dependencies(_mindspore_runtime_device_obj graph)
+endif()
@@ -28,9 +28,9 @@
 #include "utils/mpi/mpi_config.h"
 #include "runtime/device/ascend/profiling/profiling_manager.h"
 #include "common/trans.h"
 #include "runtime/context.h"
 #include "runtime/rt.h"
 #include "runtime/device/ascend/ascend_stream_assign.h"
-#include "framework/ge_runtime/model_runner.h"
+#include "runtime/device/ascend/ge_runtime/model_runner.h"
 #include "runtime/device/ascend/tasksink/task_generator.h"
 #include "backend/session/anf_runtime_algorithm.h"
 #include "runtime/device/ascend/profiling/profiling_utils.h"
@@ -40,7 +40,6 @@
 #include "toolchain/adx_datadump_server.h"
 #include "utils/trace_base.h"
-#include "graphengine/inc/external/acl/error_codes/rt_error_codes.h"
 #include "utils/runtime_error_codes.h"
 #include "debug/anf_ir_dump.h"
 #ifdef MEM_REUSE_DEBUG
 #include "backend/optimizer/mem_reuse/mem_reuse_checker.h"
@@ -61,10 +60,10 @@ using mindspore::dataset::TdtHandle;
 #include "debug/rdr/running_data_recorder.h"
 #endif

-using ge::model_runner::ModelRunner;
 using mindspore::device::ascend::ProfilingManager;
 using mindspore::device::ascend::ProfilingUtils;
 using mindspore::device::ascend::tasksink::TaskGenerator;
+using mindspore::ge::model_runner::ModelRunner;
 using mindspore::kernel::tbe::TbeUtils;
 using std::vector;

@@ -158,10 +157,7 @@ void AscendKernelRuntime::ClearGraphModelMap() {
   graph_kernel_events_map_.clear();
   for (auto &iter : graph_model_map_) {
     MS_LOG(INFO) << "Ge UnloadModel " << iter.first;
-    auto ret = ModelRunner::Instance().UnloadModel(iter.first);
-    if (!ret) {
-      MS_LOG(ERROR) << "UnloadModel failed";
-    }
+    ModelRunner::Instance().UnloadModel(iter.first);
   }
 }

@@ -194,10 +190,7 @@ void AscendKernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id, const std
   MS_LOG(DEBUG) << "Clear graph:" << graph_id << " runtime resource";
   if (auto model_iter = graph_model_map_.find(graph_id); model_iter != graph_model_map_.end()) {
     MS_LOG(DEBUG) << "Ge UnloadModel " << graph_id;
-    auto ret = ModelRunner::Instance().UnloadModel(graph_id);
-    if (!ret) {
-      MS_LOG(ERROR) << "UnloadModel failed";
-    }
+    ModelRunner::Instance().UnloadModel(graph_id);
     graph_model_map_.erase(model_iter);
   } else {
     MS_LOG(DEBUG) << "GraphId:" << graph_id << " not found";
@@ -482,10 +475,9 @@ bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) {
                << ", total label num:" << graph->label_num()
                << ", wait_active_stream_list size:" << wait_active_stream_list.size()
                << ", force_copy_stream_list size:" << force_copy_stream_list.size();
-  std::vector<std::shared_ptr<ge::model_runner::OpInfo>> empty_list;
   auto model = std::make_shared<ge::model_runner::DavinciModel>(
-    task_info_list, empty_list, empty_list, empty_list, empty_list, wait_active_stream_list, force_copy_stream_list, 0,
-    0, 0, 0, 0, 0, resource_manager.get_cur_stream_num(), graph->label_num(), resource_manager.get_cur_event_num(), 0);
+    task_info_list, wait_active_stream_list, force_copy_stream_list, 0, 0, 0, 0, 0, 0,
+    resource_manager.get_cur_stream_num(), graph->label_num(), resource_manager.get_cur_event_num(), 0);
   auto ret = graph_model_map_.insert(std::make_pair(graph->graph_id(), model));
   if (!ret.second) {
     MS_LOG(EXCEPTION) << "Duplicate GraphId! Please check in ascend_session.";
@@ -514,24 +506,20 @@ bool AscendKernelRuntime::LoadTask(const session::KernelGraph *graph) {
     return false;
   }

-  std::shared_ptr<ge::ModelListener> listener;
   MS_LOG(INFO) << "LoadDavinciModel mode_id:" << model_iter->first;
-  bool status =
-    ModelRunner::Instance().LoadDavinciModel(device_id_, 0, model_iter->first, model_iter->second, listener);
-  if (!status) {
-    MS_LOG(EXCEPTION) << "Load Model Failed";
-  }
+  ModelRunner::Instance().LoadDavinciModel(device_id_, 0, model_iter->first, model_iter->second);

   std::function<void *()> model_handle =
     std::bind(&ModelRunner::GetModelHandle, &ModelRunner::Instance(), model_iter->first);
   DistributeDebugTask(NOT_NULL(graph), NOT_NULL(model_handle));

-  status = ModelRunner::Instance().DistributeTask(model_iter->first);
-  if (!status) {
+  try {
+    ModelRunner::Instance().DistributeTask(model_iter->first);
+  } catch (const std::exception &e) {
 #ifdef ENABLE_DUMP_IR
     mindspore::RDR::TriggerAll();
 #endif
-    MS_LOG(EXCEPTION) << "Distribute Task Failed";
+    MS_LOG(EXCEPTION) << "Distribute Task Failed, error: " << e.what();
   }

   if (ProfilingManager::GetInstance().IsProfiling()) {
@@ -542,10 +530,7 @@ bool AscendKernelRuntime::LoadTask(const session::KernelGraph *graph) {

   LaunchDataDump(graph->graph_id());

-  if (!ModelRunner::Instance().LoadModelComplete(model_iter->first)) {
-    MS_LOG(ERROR) << "Call ge runtime LoadModelComplete failed";
-    return false;
-  }
+  ModelRunner::Instance().LoadModelComplete(model_iter->first);
   return true;
 }

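Note: with the internal ge runtime, ModelRunner methods report failure by throwing instead of returning a status bool, so the call sites above collapse into plain calls or a try/catch. A minimal sketch of the resulting load sequence, using only names that appear in this diff (model_id stands in for model_iter->first):

    try {
      ModelRunner::Instance().LoadDavinciModel(device_id_, 0, model_id, davinci_model);
      ModelRunner::Instance().DistributeTask(model_id);
      ModelRunner::Instance().LoadModelComplete(model_id);
    } catch (const std::exception &e) {
      MS_LOG(EXCEPTION) << "Load model " << model_id << " failed, error: " << e.what();
    }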
@@ -730,8 +715,6 @@ bool AscendKernelRuntime::RunTask(const session::KernelGraph *graph) {

   auto context_ptr = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context_ptr);
-  ge::InputData input_tensors = ge::InputData();
-  ge::OutputData *output_tensors = nullptr;
   if (GraphWithEmptyTaskList(graph)) {
     MS_LOG(WARNING) << "RunTask end, no task info found";
     return true;
@@ -742,8 +725,9 @@ bool AscendKernelRuntime::RunTask(const session::KernelGraph *graph) {
     return false;
   }

-  bool status = ModelRunner::Instance().RunModel(graph->graph_id(), input_tensors, output_tensors);
-  if (!status) {
+  try {
+    ModelRunner::Instance().RunModel(graph->graph_id());
+  } catch (const std::exception &) {
     DumpTaskExceptionInfo(graph);
     std::string file_name = "task_error_debug" + std::to_string(graph->graph_id()) + ".ir";
     auto graph_tmp = std::make_shared<session::KernelGraph>(*graph);
@@ -988,7 +972,7 @@ void AscendKernelRuntime::KernelLaunchProfiling(const std::string &kernel_name)
 }

 uint64_t AscendKernelRuntime::GetAvailableMemMaxSize() const {
-  auto ascend_mem_manager = dynamic_pointer_cast<AscendMemoryManager>(mem_manager_);
+  auto ascend_mem_manager = std::dynamic_pointer_cast<AscendMemoryManager>(mem_manager_);
   return ascend_mem_manager->GetDeviceMemSize();
 }

@@ -25,15 +25,15 @@
 #include <unordered_set>
 #include "runtime/device/kernel_runtime.h"
 #include "runtime/context.h"
-#include "framework/ge_runtime/davinci_model.h"
+#include "runtime/device/ascend/ge_runtime/davinci_model.h"
 #include "runtime/device/kernel_runtime_manager.h"
 #include "backend/session/session_basic.h"
 #include "runtime/device/ascend/dump/data_dumper.h"

-using ge::model_runner::TaskInfo;
 using std::unordered_map;
 using std::vector;
 namespace mindspore::device::ascend {
+using ge::model_runner::TaskInfo;
 class AscendKernelRuntime : public KernelRuntime {
  public:
   AscendKernelRuntime() = default;
@@ -16,6 +16,7 @@

 #include <algorithm>
 #include "runtime/device/ascend/ascend_memory_pool.h"
+#include "runtime/mem.h"
 #include "runtime/device/ascend/ascend_kernel_runtime.h"
 #include "utils/log_adapter.h"

@@ -0,0 +1,92 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_DAVINCI_MODEL_H_
+#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_DAVINCI_MODEL_H_
+
+#include <memory>
+#include <vector>
+#include "runtime/device/ascend/ge_runtime/task_info.h"
+
+namespace mindspore::ge::model_runner {
+class DavinciModel {
+ public:
+  DavinciModel(const std::vector<std::shared_ptr<TaskInfo>> &task_info_list,
+               const std::vector<uint32_t> &wait_active_stream_list,
+               const std::vector<uint32_t> &force_copy_stream_list, uint64_t mem_size = 0, uint64_t weight_size = 0,
+               uint64_t var_size = 0, uintptr_t logic_mem_base = 0, uintptr_t logic_weight_base = 0,
+               uintptr_t logic_var_base = 0, uint32_t stream_num = 0, uint32_t batch_num = 0, uint32_t event_num = 0,
+               int32_t priority = 0)
+      : task_info_list_(task_info_list),
+        wait_active_stream_list_(wait_active_stream_list),
+        force_copy_stream_list_(force_copy_stream_list),
+        mem_size_(mem_size),
+        weight_size_(weight_size),
+        var_size_(var_size),
+        logic_mem_base_(logic_mem_base),
+        logic_weight_base_(logic_weight_base),
+        logic_var_base_(logic_var_base),
+        stream_num_(stream_num),
+        batch_num_(batch_num),
+        event_num_(event_num),
+        priority_(priority) {}
+  ~DavinciModel() {}
+
+  uint64_t GetMemSize() const { return mem_size_; }
+  uint64_t GetWeightSize() const { return weight_size_; }
+  uint64_t GetVarSize() const { return var_size_; }
+
+  uintptr_t GetLogicMemBase() const { return logic_mem_base_; }
+  uintptr_t GetLogicWeightBase() const { return logic_weight_base_; }
+  uintptr_t GetLogicVarBase() const { return logic_var_base_; }
+
+  uint32_t GetStreamNum() const { return stream_num_; }
+  uint32_t GetBatchNum() const { return batch_num_; }
+  uint32_t GetEventNum() const { return event_num_; }
+
+  const std::vector<uint32_t> &GetWaitActiveStreams() const { return wait_active_stream_list_; }
+  const std::vector<uint32_t> &GetForceCopyStreams() const { return force_copy_stream_list_; }
+
+  int32_t GetPriority() const { return priority_; }
+
+  const std::vector<std::shared_ptr<TaskInfo>> &GetTaskInfoList() const { return task_info_list_; }
+
+ private:
+  std::vector<std::shared_ptr<TaskInfo>> task_info_list_;
+
+  std::vector<uint32_t> wait_active_stream_list_;
+  std::vector<uint32_t> force_copy_stream_list_;
+
+  uint64_t mem_size_;
+  uint64_t weight_size_;
+  uint64_t var_size_;
+
+  uintptr_t logic_mem_base_;
+  uintptr_t logic_weight_base_;
+  uintptr_t logic_var_base_;
+
+  uint32_t stream_num_;
+  uint32_t batch_num_;
+  uint32_t event_num_;
+
+  int32_t priority_;
+
+  // Disable copy constructor and assignment operator
+  DavinciModel &operator=(const DavinciModel &) = delete;
+  DavinciModel(const DavinciModel &) = delete;
+};
+}  // namespace mindspore::ge::model_runner
+#endif  // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_DAVINCI_MODEL_H_
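Note: DavinciModel is a plain value container. A hedged sketch of how a caller assembles one, mirroring AscendKernelRuntime::GenTask above (the local variable names here are illustrative, not from the source):

    std::vector<std::shared_ptr<TaskInfo>> task_info_list;  // filled per kernel by TaskGenerator
    std::vector<uint32_t> wait_active_streams;              // streams activated later by other streams
    std::vector<uint32_t> force_copy_streams;               // streams that need RT_STREAM_FORCE_COPY
    uint32_t stream_num = 1, label_num = 0, event_num = 0;  // resource counts from stream assignment
    auto model = std::make_shared<DavinciModel>(task_info_list, wait_active_streams, force_copy_streams,
                                                0, 0, 0, 0, 0, 0,  // mem/weight/var sizes and logic bases, unused here
                                                stream_num, label_num, event_num, 0 /* priority */);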
@@ -0,0 +1,59 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_MODEL_CONTEXT_H_
+#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_MODEL_CONTEXT_H_
+
+#include <vector>
+#include "runtime/rt_model.h"
+
+namespace mindspore::ge::model_runner {
+class ModelContext {
+ public:
+  ModelContext(uint32_t device_id, uint64_t session_id, int32_t priority, rtModel_t rt_model_handle,
+               rtStream_t rt_model_stream, const std::vector<rtStream_t> &stream_list,
+               const std::vector<rtLabel_t> &label_list, const std::vector<rtEvent_t> &event_list)
+      : device_id_(device_id),
+        session_id_(session_id),
+        priority_(priority),
+        rt_model_handle_(rt_model_handle),
+        rt_model_stream_(rt_model_stream),
+        stream_list_(stream_list),
+        label_list_(label_list),
+        event_list_(event_list) {}
+  ~ModelContext() {}
+
+  uint64_t device_id() const { return device_id_; }
+  uint64_t session_id() const { return session_id_; }
+  int32_t priority() const { return priority_; }
+  const rtModel_t &rt_model_handle() const { return rt_model_handle_; }
+  const rtStream_t &rt_model_stream() const { return rt_model_stream_; }
+  const std::vector<rtStream_t> &stream_list() const { return stream_list_; }
+  const std::vector<rtLabel_t> &label_list() const { return label_list_; }
+  const std::vector<rtEvent_t> &event_list() const { return event_list_; }
+
+ private:
+  uint32_t device_id_;
+  uint64_t session_id_;
+  int32_t priority_;
+  rtModel_t rt_model_handle_;
+  rtStream_t rt_model_stream_;
+  std::vector<rtStream_t> stream_list_;
+  std::vector<rtLabel_t> label_list_;
+  std::vector<rtEvent_t> event_list_;
+};
+}  // namespace mindspore::ge::model_runner
+#endif  // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_MODEL_CONTEXT_H_
@@ -0,0 +1,104 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "runtime/device/ascend/ge_runtime/model_runner.h"
+#include "runtime/device/ascend/ge_runtime/runtime_model.h"
+#include "runtime/device/ascend/ge_runtime/davinci_model.h"
+#include "mindspore/core/utils/log_adapter.h"
+
+namespace mindspore::ge::model_runner {
+ModelRunner &ModelRunner::Instance() {
+  static ModelRunner instance;  // Guaranteed to be destroyed.
+  return instance;
+}
+
+void ModelRunner::LoadDavinciModel(uint32_t device_id, uint64_t session_id, uint32_t model_id,
+                                   const std::shared_ptr<DavinciModel> &davinci_model) {
+  std::shared_ptr<RuntimeModel> model = std::make_shared<RuntimeModel>();
+  model->Load(device_id, session_id, davinci_model);
+  runtime_models_[model_id] = model;
+}
+
+void ModelRunner::DistributeTask(uint32_t model_id) {
+  auto model_iter = runtime_models_.find(model_id);
+  if (model_iter == runtime_models_.end()) {
+    MS_LOG(EXCEPTION) << "Model id " << model_id << " not found.";
+  }
+  MS_EXCEPTION_IF_NULL(model_iter->second);
+  model_iter->second->DistributeTask();
+}
+
+void ModelRunner::LoadModelComplete(uint32_t model_id) {
+  auto model_iter = runtime_models_.find(model_id);
+  if (model_iter == runtime_models_.end()) {
+    MS_LOG(EXCEPTION) << "Model id " << model_id << " not found.";
+  }
+  MS_EXCEPTION_IF_NULL(model_iter->second);
+  model_iter->second->LoadComplete();
+}
+
+const std::vector<uint32_t> &ModelRunner::GetTaskIdList(uint32_t model_id) const {
+  auto model_iter = runtime_models_.find(model_id);
+  if (model_iter == runtime_models_.end()) {
+    MS_LOG(EXCEPTION) << "Model id " << model_id << " not found.";
+  }
+  MS_EXCEPTION_IF_NULL(model_iter->second);
+  return model_iter->second->GetTaskIdList();
+}
+
+const std::vector<uint32_t> &ModelRunner::GetStreamIdList(uint32_t model_id) const {
+  auto model_iter = runtime_models_.find(model_id);
+  if (model_iter == runtime_models_.end()) {
+    MS_LOG(EXCEPTION) << "Model id " << model_id << " not found.";
+  }
+  MS_EXCEPTION_IF_NULL(model_iter->second);
+  return model_iter->second->GetStreamIdList();
+}
+
+const std::map<std::string, std::shared_ptr<RuntimeInfo>> &ModelRunner::GetRuntimeInfoMap(uint32_t model_id) const {
+  auto model_iter = runtime_models_.find(model_id);
+  if (model_iter == runtime_models_.end()) {
+    MS_LOG(EXCEPTION) << "Model id " << model_id << " not found.";
+  }
+  MS_EXCEPTION_IF_NULL(model_iter->second);
+  return model_iter->second->GetRuntimeInfoMap();
+}
+
+void *ModelRunner::GetModelHandle(uint32_t model_id) const {
+  auto model_iter = runtime_models_.find(model_id);
+  if (model_iter == runtime_models_.end()) {
+    MS_LOG(EXCEPTION) << "Model id " << model_id << " not found.";
+  }
+  MS_EXCEPTION_IF_NULL(model_iter->second);
+  return model_iter->second->GetModelHandle();
+}
+
+void ModelRunner::UnloadModel(uint32_t model_id) {
+  auto iter = runtime_models_.find(model_id);
+  if (iter != runtime_models_.end()) {
+    (void)runtime_models_.erase(iter);
+  }
+}
+
+void ModelRunner::RunModel(uint32_t model_id) {
+  auto model_iter = runtime_models_.find(model_id);
+  if (model_iter == runtime_models_.end()) {
+    MS_LOG(EXCEPTION) << "Model id " << model_id << " not found.";
+  }
+  MS_EXCEPTION_IF_NULL(model_iter->second);
+  model_iter->second->Run();
+}
+}  // namespace mindspore::ge::model_runner
@@ -0,0 +1,60 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_MODEL_RUNNER_H_
+#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_MODEL_RUNNER_H_
+
+#include <memory>
+#include <map>
+#include <vector>
+#include <tuple>
+#include <string>
+#include "runtime/device/ascend/ge_runtime/davinci_model.h"
+
+namespace mindspore::ge::model_runner {
+class RuntimeModel;
+using RuntimeInfo = std::tuple<uint32_t, uint32_t, void *>;
+class ModelRunner {
+ public:
+  static ModelRunner &Instance();
+
+  void LoadDavinciModel(uint32_t device_id, uint64_t session_id, uint32_t model_id,
+                        const std::shared_ptr<DavinciModel> &davinci_model);
+
+  void DistributeTask(uint32_t model_id);
+
+  void LoadModelComplete(uint32_t model_id);
+
+  const std::vector<uint32_t> &GetTaskIdList(uint32_t model_id) const;
+
+  const std::vector<uint32_t> &GetStreamIdList(uint32_t model_id) const;
+
+  const std::map<std::string, std::shared_ptr<RuntimeInfo>> &GetRuntimeInfoMap(uint32_t model_id) const;
+
+  void *GetModelHandle(uint32_t model_id) const;
+
+  void UnloadModel(uint32_t model_id);
+
+  void RunModel(uint32_t model_id);
+
+ private:
+  ModelRunner() = default;
+  ~ModelRunner() = default;
+
+  std::map<uint32_t, std::shared_ptr<RuntimeModel>> runtime_models_;
+};
+}  // namespace mindspore::ge::model_runner
+#endif  // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_MODEL_RUNNER_H_
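Note: the singleton drives a fixed model lifecycle. A sketch of the expected call order (model_id is chosen by the caller; in this commit it is the graph id):

    auto &runner = ModelRunner::Instance();
    runner.LoadDavinciModel(device_id, session_id, model_id, davinci_model);  // builds a RuntimeModel: streams, labels, events, tasks
    runner.DistributeTask(model_id);     // submits every task to its stream; throws on failure
    runner.LoadModelComplete(model_id);  // rtModelLoadComplete
    runner.RunModel(model_id);           // rtModelExecute + rtStreamSynchronize; throws on failure
    runner.UnloadModel(model_id);        // erases the RuntimeModel; its destructor frees rt resources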
@@ -0,0 +1,292 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "runtime/device/ascend/ge_runtime/runtime_model.h"
+#include <set>
+#include "runtime/kernel.h"
+#include "runtime/rt_model.h"
+#include "graphengine/inc/external/runtime/rt_error_codes.h"
+#include "runtime/device/ascend/ge_runtime/model_context.h"
+#include "runtime/device/ascend/ge_runtime/task/task.h"
+#include "runtime/device/ascend/ge_runtime/task/task_factory.h"
+#include "mindspore/core/utils/log_adapter.h"
+
+namespace mindspore::ge::model_runner {
+RuntimeModel::~RuntimeModel() {
+  MS_LOG(INFO) << "RuntimeModel destructor start.";
+
+  // Unbind rtModel from all task related streams
+  RtModelUnbindStream();
+
+  // Release tasks first, hccl task holds stream
+  task_list_.clear();
+
+  // Release all task related streams
+  RtStreamDestory();
+
+  // Release rtLabel resources
+  RtLabelDestory();
+
+  // Release rtEvent resources
+  RtEventDestory();
+
+  MS_LOG(INFO) << "Do RtModelDestroy";
+  // Release the rt_model itself
+  RtModelDestory();
+}
+
+void RuntimeModel::InitStream(const std::shared_ptr<DavinciModel> &davinci_model) {
+  MS_EXCEPTION_IF_NULL(davinci_model);
+
+  std::set<int64_t> wait_active_streams;
+  std::set<int64_t> force_copy_streams;
+
+  for (const auto &stream_id : davinci_model->GetWaitActiveStreams()) {
+    MS_LOG(INFO) << "Stream id " << stream_id << " is wait active stream.";
+    (void)wait_active_streams.insert(stream_id);
+  }
+
+  for (const auto &stream_id : davinci_model->GetForceCopyStreams()) {
+    MS_LOG(INFO) << "Stream id " << stream_id << " is force copy stream.";
+    (void)force_copy_streams.insert(stream_id);
+  }
+
+  MS_LOG(INFO) << "Total stream num " << davinci_model->GetStreamNum();
+  for (uint32_t i = 0; i < davinci_model->GetStreamNum(); ++i) {
+    rtStream_t stream = nullptr;
+    uint32_t flag = (force_copy_streams.find(i) != force_copy_streams.end())
+                      ? (RT_STREAM_PERSISTENT | RT_STREAM_FORCE_COPY)
+                      : (RT_STREAM_PERSISTENT);
+
+    rtError_t rt_ret = rtStreamCreateWithFlags(&stream, davinci_model->GetPriority(), flag);
+    if (rt_ret != RT_ERROR_NONE) {
+      MS_LOG(EXCEPTION) << "Call rt api rtStreamCreateWithFlags failed, ret: " << std::hex << rt_ret;
+    }
+
+    MS_LOG(INFO) << "rtStreamCreateWithFlags end.";
+    stream_list_.emplace_back(stream);
+
+    // Bind rt_model_handle_ to all task related streams
+    flag = (wait_active_streams.find(i) != wait_active_streams.end()) ? (static_cast<uint32_t>(RT_INVALID_FLAG))
+                                                                      : (static_cast<uint32_t>(RT_HEAD_STREAM));
+    rt_ret = rtModelBindStream(rt_model_handle_, stream, flag);
+    if (rt_ret != RT_ERROR_NONE) {
+      MS_LOG(EXCEPTION) << "Call rt api rtModelBindStream failed, ret: " << std::hex << rt_ret;
+    }
+    MS_LOG(INFO) << "stream index: " << i << ", stream: " << std::hex << stream;
+  }
+}
+
+void RuntimeModel::InitEvent(uint32_t event_num) {
+  MS_LOG(INFO) << "Event number: " << event_num;
+  for (uint32_t i = 0; i < event_num; ++i) {
+    rtEvent_t rt_event;
+    rtError_t rt_ret = rtEventCreate(&rt_event);
+    if (rt_ret != RT_ERROR_NONE) {
+      MS_LOG(EXCEPTION) << "Call rt api rtEventCreate failed, ret: " << std::hex << rt_ret;
+    }
+    event_list_.push_back(rt_event);
+  }
+}
+
+void RuntimeModel::InitLabel(const std::shared_ptr<DavinciModel> &davinci_model) {
+  MS_LOG(INFO) << "Label number: " << davinci_model->GetBatchNum();
+  label_list_.resize(davinci_model->GetBatchNum());
+  for (auto &task_info : davinci_model->GetTaskInfoList()) {
+    MS_EXCEPTION_IF_NULL(task_info);
+
+    if (task_info->type() != TaskInfoType::LABEL_SET) {
+      continue;
+    }
+    auto label_set_task_info = std::static_pointer_cast<LabelSetTaskInfo>(task_info);
+
+    if (label_set_task_info->stream_id() >= stream_list_.size()) {
+      MS_LOG(EXCEPTION) << "Invalid stream id " << label_set_task_info->stream_id() << " total stream num "
+                        << stream_list_.size();
+    }
+
+    rtLabel_t rt_label = nullptr;
+    rtError_t rt_ret = rtLabelCreateEx(&rt_label, stream_list_[label_set_task_info->stream_id()]);
+    if (rt_ret != RT_ERROR_NONE) {
+      MS_LOG(EXCEPTION) << "Call rt api rtLabelCreateEx failed, ret: " << std::hex << rt_ret;
+    }
+    label_list_[label_set_task_info->label_id()] = rt_label;
+  }
+}
+
+void RuntimeModel::InitResource(const std::shared_ptr<DavinciModel> &davinci_model) {
+  MS_LOG(INFO) << "InitResource start";
+  MS_EXCEPTION_IF_NULL(davinci_model);
+
+  rtError_t rt_ret = rtModelCreate(&rt_model_handle_, 0);
+  if (rt_ret != RT_ERROR_NONE) {
+    MS_LOG(EXCEPTION) << "Call rt api rtModelCreate failed, ret: " << std::hex << rt_ret;
+  }
+
+  // Create rtStream for rt_model_handle_
+  rt_ret = rtStreamCreate(&rt_model_stream_, davinci_model->GetPriority());
+  if (rt_ret != RT_ERROR_NONE) {
+    MS_LOG(EXCEPTION) << "Call rt api rtStreamCreate failed, ret: " << std::hex << rt_ret;
+  }
+  MS_LOG(INFO) << "rtStreamCreate end";
+
+  InitStream(davinci_model);
+  InitEvent(davinci_model->GetEventNum());
+  InitLabel(davinci_model);
+
+  MS_LOG(INFO) << "InitResource success";
+}
+
+void RuntimeModel::GenerateTask(uint32_t device_id, uint64_t session_id,
+                                const std::shared_ptr<DavinciModel> &davinci_model) {
+  MS_LOG(INFO) << "GenerateTask start.";
+  MS_EXCEPTION_IF_NULL(davinci_model);
+  auto task_infos = davinci_model->GetTaskInfoList();
+  ModelContext model_context(device_id, session_id, davinci_model->GetPriority(), rt_model_handle_, rt_model_stream_,
+                             stream_list_, label_list_, event_list_);
+  for (auto &task_info : task_infos) {
+    auto task = TaskFactory::GetInstance().Create(model_context, task_info);
+    task_list_.push_back(task);
+  }
+  MS_LOG(INFO) << "GenerateTask success.";
+}
+
+void RuntimeModel::LoadComplete() {
+  uint32_t task_id = 0;
+  uint32_t stream_id = 0;
+  auto rt_ret = rtModelGetTaskId(rt_model_handle_, &task_id, &stream_id);
+  if (rt_ret != RT_ERROR_NONE) {
+    MS_LOG(EXCEPTION) << "Call rt api rtModelGetTaskId failed, ret: " << std::hex << rt_ret;
+  }
+  task_id_list_.push_back(task_id);
+  stream_id_list_.push_back(stream_id);
+
+  rt_ret = rtModelLoadComplete(rt_model_handle_);
+  if (rt_ret != RT_ERROR_NONE) {
+    MS_LOG(EXCEPTION) << "Call rt api rtModelLoadComplete failed, ret: " << std::hex << rt_ret;
+  }
+}
+
+void RuntimeModel::Load(uint32_t device_id, uint64_t session_id, const std::shared_ptr<DavinciModel> &davinci_model) {
+  InitResource(davinci_model);
+  GenerateTask(device_id, session_id, davinci_model);
+}
+
+void RuntimeModel::DistributeTask() {
+  MS_LOG(INFO) << "DistributeTask start.";
+  for (auto &task : task_list_) {
+    MS_EXCEPTION_IF_NULL(task);
+    task->Distribute();
+
+    uint32_t task_id = 0;
+    uint32_t stream_id = 0;
+    rtError_t rt_ret = rtModelGetTaskId(rt_model_handle_, &task_id, &stream_id);
+    if (rt_ret != RT_ERROR_NONE) {
+      MS_LOG(EXCEPTION) << "Call rt api rtModelGetTaskId failed, ret: " << std::hex << rt_ret;
+    }
+    task_id_list_.push_back(task_id);
+    stream_id_list_.push_back(stream_id);
+    if (task->Args() != nullptr) {
+      std::shared_ptr<RuntimeInfo> runtime_tuple = std::make_shared<RuntimeInfo>(task_id, stream_id, task->Args());
+      auto emplace_ret = runtime_info_map_.emplace(task->task_name(), runtime_tuple);
+      if (!emplace_ret.second) {
+        MS_LOG(WARNING) << "Task name exist: " << task->task_name();
+      }
+    }
+  }
+  if (task_list_.empty()) {
+    MS_LOG(EXCEPTION) << "Task list is empty";
+  }
+
+  MS_LOG(INFO) << "DistributeTask success.";
+}
+
+void RuntimeModel::Run() {
+  MS_LOG(INFO) << "Davinci task run start.";
+  rtError_t ret = rtModelExecute(rt_model_handle_, rt_model_stream_, 0);
+  if (ret != RT_ERROR_NONE) {
+    MS_LOG(EXCEPTION) << "Call rt api rtModelExecute failed, ret: " << std::hex << ret;
+  }
+
+  MS_LOG(INFO) << "Run rtModelExecute success, start to rtStreamSynchronize.";
+  ret = rtStreamSynchronize(rt_model_stream_);
+  if (ret != RT_ERROR_NONE) {
+    if (ret == ACL_ERROR_RT_END_OF_SEQUENCE) {
+      MS_LOG(INFO) << "Model stream ACL_ERROR_RT_END_OF_SEQUENCE signal received.";
+      return;
+    }
+    MS_LOG(EXCEPTION) << "Call rt api rtStreamSynchronize failed, ret: " << std::hex << ret;
+  }
+
+  MS_LOG(INFO) << "Davinci task run success.";
+}
+
+void RuntimeModel::RtModelUnbindStream() noexcept {
+  for (size_t i = 0; i < stream_list_.size(); i++) {
+    if (rtModelUnbindStream(rt_model_handle_, stream_list_[i]) != RT_ERROR_NONE) {
+      MS_LOG(ERROR) << "Unbind stream from model failed! Index: " << i;
+      return;
+    }
+  }
+}
+
+void RuntimeModel::RtStreamDestory() noexcept {
+  if (rtStreamDestroy(rt_model_stream_) != RT_ERROR_NONE) {
+    MS_LOG(ERROR) << "Destroy stream for rt_model failed!";
+    return;
+  }
+
+  for (size_t i = 0; i < stream_list_.size(); i++) {
+    if (rtStreamDestroy(stream_list_[i]) != RT_ERROR_NONE) {
+      MS_LOG(ERROR) << "Destroy stream failed! Index: " << i;
+      return;
+    }
+  }
+}
+
+void RuntimeModel::RtLabelDestory() noexcept {
+  for (size_t i = 0; i < label_list_.size(); i++) {
+    if (label_list_[i] == nullptr) {
+      continue;
+    }
+    if (rtLabelDestroy(label_list_[i]) != RT_ERROR_NONE) {
+      MS_LOG(ERROR) << "Destroy label failed! Index: " << i;
+      return;
+    }
+  }
+}
+
+void RuntimeModel::RtModelDestory() noexcept {
+  rtError_t ret = rtModelDestroy(rt_model_handle_);
+  if (ret != RT_ERROR_NONE) {
+    MS_LOG(ERROR) << "Call rt api rtModelDestroy failed, ret: " << std::hex << ret;
+    return;
+  }
+}
+
+void RuntimeModel::RtEventDestory() noexcept {
+  for (size_t i = 0; i < event_list_.size(); i++) {
+    if (rtEventDestroy(event_list_[i]) != RT_ERROR_NONE) {
+      MS_LOG(ERROR) << "Destroy event failed! Index: " << i;
+      return;
+    }
+  }
+}
+
+const std::vector<uint32_t> &RuntimeModel::GetTaskIdList() const { return task_id_list_; }
+
+const std::vector<uint32_t> &RuntimeModel::GetStreamIdList() const { return stream_id_list_; }
+}  // namespace mindspore::ge::model_runner
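Note: the RuntimeInfo entries recorded in DistributeTask are (task_id, stream_id, args) tuples keyed by task name, later consumed by the data dumper. A hedged sketch of unpacking one map (the loop variable names are illustrative):

    for (const auto &[task_name, info] : ModelRunner::Instance().GetRuntimeInfoMap(model_id)) {
      const auto &[task_id, stream_id, args] = *info;  // RuntimeInfo = std::tuple<uint32_t, uint32_t, void *>
      MS_LOG(INFO) << task_name << ": task " << task_id << ", stream " << stream_id << ", args " << args;
    }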
@@ -0,0 +1,71 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_RUNTIME_MODEL_H_
+#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_RUNTIME_MODEL_H_
+#include <map>
+#include <memory>
+#include <string>
+#include <vector>
+#include <tuple>
+#include "runtime/base.h"
+#include "runtime/rt_model.h"
+#include "runtime/device/ascend/ge_runtime/davinci_model.h"
+
+namespace mindspore::ge::model_runner {
+using RuntimeInfo = std::tuple<uint32_t, uint32_t, void *>;
+class Task;
+class RuntimeModel {
+ public:
+  RuntimeModel() = default;
+  ~RuntimeModel();
+
+  void Load(uint32_t device_id, uint64_t session_id, const std::shared_ptr<DavinciModel> &davinci_model);
+  void DistributeTask();
+  void LoadComplete();
+  const std::vector<uint32_t> &GetTaskIdList() const;
+  const std::vector<uint32_t> &GetStreamIdList() const;
+  const std::map<std::string, std::shared_ptr<RuntimeInfo>> &GetRuntimeInfoMap() const { return runtime_info_map_; }
+  rtModel_t GetModelHandle() const { return rt_model_handle_; }
+  void Run();
+
+ private:
+  void InitResource(const std::shared_ptr<DavinciModel> &davinci_model);
+  void GenerateTask(uint32_t device_id, uint64_t session_id, const std::shared_ptr<DavinciModel> &davinci_model);
+  void InitStream(const std::shared_ptr<DavinciModel> &davinci_model);
+  void InitEvent(uint32_t event_num);
+  void InitLabel(const std::shared_ptr<DavinciModel> &davinci_model);
+  void RtModelUnbindStream() noexcept;
+  void RtStreamDestory() noexcept;
+  void RtModelDestory() noexcept;
+  void RtLabelDestory() noexcept;
+  void RtEventDestory() noexcept;
+
+  rtModel_t rt_model_handle_{};
+  rtStream_t rt_model_stream_{};
+
+  std::vector<rtStream_t> stream_list_{};
+  std::vector<rtLabel_t> label_list_{};
+  std::vector<rtEvent_t> event_list_{};
+
+  std::vector<std::shared_ptr<Task>> task_list_{};
+
+  std::vector<uint32_t> task_id_list_{};
+  std::vector<uint32_t> stream_id_list_{};
+  std::map<std::string, std::shared_ptr<RuntimeInfo>> runtime_info_map_;
+};
+}  // namespace mindspore::ge::model_runner
+#endif  // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_RUNTIME_MODEL_H_
@@ -0,0 +1,168 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "runtime/device/ascend/ge_runtime/task/aicpu_task.h"
+#include <vector>
+#include "runtime/mem.h"
+#include "runtime/kernel.h"
+#include "runtime/device/ascend/ge_runtime/task/task_factory.h"
+#include "aicpu/common/aicpu_task_struct.h"
+
+namespace mindspore::ge::model_runner {
+AicpuTask::AicpuTask(const ModelContext &model_context, const std::shared_ptr<AicpuTaskInfo> &task_info)
+    : TaskRepeater<AicpuTaskInfo>(model_context, task_info),
+      task_info_(task_info),
+      stream_(nullptr),
+      args_(nullptr),
+      ext_info_(nullptr),
+      input_output_addr_(nullptr) {
+  MS_EXCEPTION_IF_NULL(task_info_);
+
+  auto stream_list = model_context.stream_list();
+  if (stream_list.size() == 1) {
+    stream_ = stream_list[0];
+  } else if (stream_list.size() > task_info_->stream_id()) {
+    stream_ = stream_list[task_info_->stream_id()];
+  } else {
+    MS_LOG(EXCEPTION) << "Index: " << task_info_->stream_id() << " >= stream_list.size(): " << stream_list.size();
+  }
+}
+
+AicpuTask::~AicpuTask() {
+  ReleaseRtMem(&args_);
+  ReleaseRtMem(&ext_info_);
+}
+
+void AicpuTask::Distribute() {
+  MS_LOG(INFO) << "InitAicpuTask start.";
+  std::vector<void *> io_addrs;
+  io_addrs.insert(io_addrs.end(), task_info_->input_data_addrs().begin(), task_info_->input_data_addrs().end());
+  io_addrs.insert(io_addrs.end(), task_info_->output_data_addrs().begin(), task_info_->output_data_addrs().end());
+  auto io_addrs_num = static_cast<uint32_t>(io_addrs.size());
+  auto io_addrs_size = static_cast<uint32_t>(io_addrs_num * sizeof(void *));
+  constexpr uint32_t io_addr_offset = sizeof(aicpu::AicpuParamHead);
+  uint32_t node_def_len_offset = io_addr_offset + io_addrs_size;
+  uint32_t node_def_addr_offset = node_def_len_offset + sizeof(uint32_t);
+  uint32_t args_size = sizeof(aicpu::AicpuParamHead) + io_addrs_size +
+                       static_cast<uint32_t>(task_info_->node_def().size()) + sizeof(uint32_t);
+
+  // Malloc device memory for args
+  rtError_t rt_ret = rtMalloc(&args_, args_size, RT_MEMORY_HBM);
+  if (rt_ret != RT_ERROR_NONE) {
+    MS_LOG(EXCEPTION) << "Call rt api rtMalloc failed, ret: " << std::hex << rt_ret;
+  }
+
+  SetAicpuParamHead(args_size, io_addrs_num);
+  SetInputOutputAddrs(io_addrs, io_addr_offset);
+  SetNodeDef(node_def_len_offset, node_def_addr_offset);
+
+  // for data dump
+  input_output_addr_ = reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(args_) + io_addr_offset);
+  auto dump_flag = task_info_->dump_flag() ? RT_KERNEL_DUMPFLAG : RT_KERNEL_DEFAULT;
+
+  MS_LOG(INFO) << "Distribute AicpuTask start, args_size = " << args_size << ", io_addrs_num =" << io_addrs_num
+               << ", so_name = " << task_info_->so_name() << ", kernel_name = " << task_info_->kernel_name()
+               << ", dump_flag = " << dump_flag;
+  rt_ret = rtCpuKernelLaunchWithFlag(reinterpret_cast<const void *>(task_info_->so_name().data()),
+                                     reinterpret_cast<const void *>(task_info_->kernel_name().data()), 1, args_,
+                                     args_size, nullptr, stream_, dump_flag);
+  if (rt_ret != RT_ERROR_NONE) {
+    MS_LOG(EXCEPTION) << "Call rt api rtCpuKernelLaunchWithFlag failed, ret: " << std::hex << rt_ret;
+  }
+
+  MS_LOG(INFO) << "Distribute AicpuTask end.";
+}
+
+void AicpuTask::ReleaseRtMem(void **ptr) noexcept {
+  if (ptr == nullptr || *ptr == nullptr) {
+    return;
+  }
+
+  rtError_t rt_ret = rtFree(*ptr);
+  if (rt_ret != RT_ERROR_NONE) {
+    return;
+  }
+  *ptr = nullptr;
+}
+
+void AicpuTask::SetAicpuParamHead(uint32_t args_size, uint32_t io_addrs_num) {
+  aicpu::AicpuParamHead aicpu_param_head;
+  aicpu_param_head.length = args_size;
+  aicpu_param_head.ioAddrNum = io_addrs_num;
+
+  const auto &ext_info = task_info_->ext_info();
+  uint32_t ext_size = ext_info.size();
+  if (ext_info.empty()) {
+    aicpu_param_head.extInfoLength = 0;
+    aicpu_param_head.extInfoAddr = 0;
+  } else {
+    rtError_t flag = rtMalloc(&ext_info_, ext_size, RT_MEMORY_HBM);
+    if (flag != RT_ERROR_NONE) {
+      MS_LOG(EXCEPTION) << "Call rt api rtMalloc failed, ret: " << std::hex << flag;
+    }
+
+    flag = rtMemcpy(ext_info_, ext_size, const_cast<void *>(reinterpret_cast<const void *>(ext_info.data())), ext_size,
+                    RT_MEMCPY_HOST_TO_DEVICE);
+    if (flag != RT_ERROR_NONE) {
+      MS_LOG(EXCEPTION) << "Call rt api rtMemcpy failed, ret: " << std::hex << flag;
+    }
+
+    MS_LOG(INFO) << "ext info size: " << ext_size;
+    aicpu_param_head.extInfoLength = ext_size;
+    aicpu_param_head.extInfoAddr = reinterpret_cast<uintptr_t>(ext_info_);
+  }
+
+  // Memcpy AicpuParamHead
+  auto rt_ret = rtMemcpy(args_, sizeof(aicpu::AicpuParamHead), reinterpret_cast<void *>(&aicpu_param_head),
+                         sizeof(aicpu::AicpuParamHead), RT_MEMCPY_HOST_TO_DEVICE);
+  if (rt_ret != RT_ERROR_NONE) {
+    MS_LOG(EXCEPTION) << "Call rt api rtMemcpy failed, ret: " << std::hex << rt_ret;
+  }
+}
+
+void AicpuTask::SetInputOutputAddrs(const std::vector<void *> &io_addrs, uint32_t io_addr_offset) {
+  // Memcpy io addrs
+  if (!io_addrs.empty()) {
+    auto rt_ret = rtMemcpy(reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(args_) + io_addr_offset),
+                           static_cast<uint32_t>(io_addrs.size()) * sizeof(void *), io_addrs.data(),
+                           static_cast<uint32_t>(io_addrs.size()) * sizeof(void *), RT_MEMCPY_HOST_TO_DEVICE);
+    if (rt_ret != RT_ERROR_NONE) {
+      MS_LOG(EXCEPTION) << "Call rt api rtMemcpy failed, ret: " << std::hex << rt_ret;
+    }
+  }
+}
+
+void AicpuTask::SetNodeDef(uint32_t node_def_len_offset, uint32_t node_def_addr_offset) {
+  // Memcpy node def length
+  auto size = task_info_->node_def().size();
+  auto rt_ret =
+    rtMemcpy(reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(args_) + node_def_len_offset), sizeof(uint32_t),
+             reinterpret_cast<const void *>(&size), sizeof(uint32_t), RT_MEMCPY_HOST_TO_DEVICE);
+  if (rt_ret != RT_ERROR_NONE) {
+    MS_LOG(EXCEPTION) << "Call rt api rtMemcpy failed, ret: " << std::hex << rt_ret;
+  }
+
+  // Memcpy node def body
+  rt_ret = rtMemcpy(reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(args_) + node_def_addr_offset),
+                    task_info_->node_def().size(), reinterpret_cast<const void *>(task_info_->node_def().data()),
+                    task_info_->node_def().size(), RT_MEMCPY_HOST_TO_DEVICE);
+  if (rt_ret != RT_ERROR_NONE) {
+    MS_LOG(EXCEPTION) << "Call rt api rtMemcpy failed, ret: " << std::hex << rt_ret;
+  }
+}
+
+REGISTER_TASK(TaskInfoType::AICPU, AicpuTask, AicpuTaskInfo);
+}  // namespace mindspore::ge::model_runner
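Note: the device args buffer assembled in AicpuTask::Distribute is one contiguous block; its layout, restated from the offsets computed above:

    // [ aicpu::AicpuParamHead | io addrs (io_addrs_num * sizeof(void *)) | node_def length (uint32_t) | node_def bytes ]
    // io_addr_offset       = sizeof(aicpu::AicpuParamHead)
    // node_def_len_offset  = io_addr_offset + io_addrs_size
    // node_def_addr_offset = node_def_len_offset + sizeof(uint32_t)
    // args_size            = node_def_addr_offset + task_info_->node_def().size()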
@@ -0,0 +1,51 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_AICPU_TASK_H_
+#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_AICPU_TASK_H_
+
+#include <vector>
+#include <memory>
+#include <string>
+#include "runtime/device/ascend/ge_runtime/task/task.h"
+
+namespace mindspore::ge::model_runner {
+class AicpuTask : public TaskRepeater<AicpuTaskInfo> {
+ public:
+  AicpuTask(const ModelContext &model_context, const std::shared_ptr<AicpuTaskInfo> &task_info);
+
+  ~AicpuTask() override;
+
+  void Distribute() override;
+
+  void *Args() override { return input_output_addr_; }
+
+  std::string task_name() const override { return task_info_->op_name(); }
+
+ private:
+  static void ReleaseRtMem(void **ptr) noexcept;
+  void SetAicpuParamHead(uint32_t args_size, uint32_t io_addrs_num);
+  void SetInputOutputAddrs(const std::vector<void *> &io_addrs, uint32_t io_addr_offset);
+  void SetNodeDef(uint32_t node_def_len_offset, uint32_t node_def_addr_offset);
+
+  std::shared_ptr<AicpuTaskInfo> task_info_;
+  void *stream_;
+  void *args_;
+  void *ext_info_;
+  void *input_output_addr_;
+};
+}  // namespace mindspore::ge::model_runner
+#endif  // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_AICPU_TASK_H_
@ -0,0 +1,54 @@
#include "runtime/device/ascend/ge_runtime/task/event_record_task.h"
#include "runtime/kernel.h"
#include "runtime/device/ascend/ge_runtime/task/task_factory.h"

namespace mindspore::ge::model_runner {
EventRecordTask::EventRecordTask(const ModelContext &model_context,
                                 const std::shared_ptr<EventRecordTaskInfo> &task_info)
    : TaskRepeater<EventRecordTaskInfo>(model_context, task_info),
      task_info_(task_info),
      stream_(nullptr),
      event_(nullptr) {
  MS_EXCEPTION_IF_NULL(task_info_);
  auto stream_list = model_context.stream_list();
  auto event_list = model_context.event_list();
  uint32_t stream_id = task_info_->stream_id();
  uint32_t event_id = task_info_->event_id();
  if (stream_id >= stream_list.size() || event_id >= event_list.size()) {
    MS_LOG(EXCEPTION) << "stream_list size: " << stream_list.size() << ", stream_id: " << stream_id
                      << ", event_list size: " << event_list.size() << ", event_id: " << event_id;
  }
  stream_ = stream_list[stream_id];
  event_ = event_list[event_id];
}

EventRecordTask::~EventRecordTask() {}

void EventRecordTask::Distribute() {
  MS_LOG(INFO) << "EventRecordTask Distribute start, stream: " << stream_ << ", event: " << event_
               << ", stream_id: " << task_info_->stream_id() << ", event_id: " << task_info_->event_id();
  rtError_t rt_ret = rtEventRecord(event_, stream_);
  if (rt_ret != RT_ERROR_NONE) {
    MS_LOG(EXCEPTION) << "Call rt api rtEventRecord failed, ret: " << std::hex << rt_ret;
  }
  MS_LOG(INFO) << "Distribute end.";
}

REGISTER_TASK(TaskInfoType::EVENT_RECORD, EventRecordTask, EventRecordTaskInfo);
}  // namespace mindspore::ge::model_runner
@ -0,0 +1,38 @@
#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_EVENT_RECORD_TASK_H_
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_EVENT_RECORD_TASK_H_

#include <memory>
#include "runtime/device/ascend/ge_runtime/task/task.h"

namespace mindspore::ge::model_runner {
class EventRecordTask : public TaskRepeater<EventRecordTaskInfo> {
 public:
  EventRecordTask(const ModelContext &model_context, const std::shared_ptr<EventRecordTaskInfo> &task_info);

  ~EventRecordTask() override;

  void Distribute() override;

 private:
  std::shared_ptr<EventRecordTaskInfo> task_info_;
  rtStream_t stream_;
  rtEvent_t event_;
};
}  // namespace mindspore::ge::model_runner
#endif  // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_EVENT_RECORD_TASK_H_
@ -0,0 +1,59 @@
#include "runtime/device/ascend/ge_runtime/task/event_wait_task.h"
#include "runtime/kernel.h"
#include "runtime/device/ascend/ge_runtime/task/task_factory.h"

namespace mindspore::ge::model_runner {
EventWaitTask::EventWaitTask(const ModelContext &model_context, const std::shared_ptr<EventWaitTaskInfo> &task_info)
    : TaskRepeater<EventWaitTaskInfo>(model_context, task_info),
      task_info_(task_info),
      stream_(nullptr),
      event_(nullptr) {
  MS_EXCEPTION_IF_NULL(task_info_);
  auto stream_list = model_context.stream_list();
  auto event_list = model_context.event_list();
  uint32_t stream_id = task_info_->stream_id();
  uint32_t event_id = task_info_->event_id();
  if (stream_id >= stream_list.size() || event_id >= event_list.size()) {
    MS_LOG(EXCEPTION) << "stream_list size: " << stream_list.size() << ", stream_id: " << stream_id
                      << ", event_list size: " << event_list.size() << ", event_id: " << event_id;
  }
  stream_ = stream_list[stream_id];
  event_ = event_list[event_id];
}

EventWaitTask::~EventWaitTask() {}

void EventWaitTask::Distribute() {
  MS_LOG(INFO) << "EventWaitTask Distribute start, stream: " << stream_ << ", event: " << event_
               << ", stream_id: " << task_info_->stream_id() << ", event_id: " << task_info_->event_id();

  rtError_t rt_ret = rtStreamWaitEvent(stream_, event_);
  if (rt_ret != RT_ERROR_NONE) {
    MS_LOG(EXCEPTION) << "Call rt api rtStreamWaitEvent failed, ret: " << std::hex << rt_ret;
  }

  rt_ret = rtEventReset(event_, stream_);
  if (rt_ret != RT_ERROR_NONE) {
    MS_LOG(EXCEPTION) << "Call rt api rtEventReset failed, ret: " << std::hex << rt_ret;
  }
  MS_LOG(INFO) << "Distribute end.";
}

REGISTER_TASK(TaskInfoType::EVENT_WAIT, EventWaitTask, EventWaitTaskInfo);
}  // namespace mindspore::ge::model_runner
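EventRecordTask and EventWaitTask are two halves of one cross-stream ordering primitive: the producer stream records the event, the consumer stream waits on it and then resets it so the same event can be reused on the next iteration. A condensed sketch of the sequence (error checks elided; assumes the Ascend runtime headers for the rt* handles and calls):

#include "runtime/kernel.h"  // rtEventRecord / rtStreamWaitEvent / rtEventReset

// Sketch: the cross-stream ordering pattern the two tasks above implement.
void OrderAcrossStreams(rtStream_t producer, rtStream_t consumer, rtEvent_t event) {
  (void)rtEventRecord(event, producer);      // EventRecordTask::Distribute
  (void)rtStreamWaitEvent(consumer, event);  // EventWaitTask::Distribute, step 1
  (void)rtEventReset(event, consumer);       // step 2: make the event reusable
}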
@ -0,0 +1,38 @@
#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_EVENT_WAIT_TASK_H_
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_EVENT_WAIT_TASK_H_

#include <memory>
#include "runtime/device/ascend/ge_runtime/task/task.h"

namespace mindspore::ge::model_runner {
class EventWaitTask : public TaskRepeater<EventWaitTaskInfo> {
 public:
  EventWaitTask(const ModelContext &model_context, const std::shared_ptr<EventWaitTaskInfo> &task_info);

  ~EventWaitTask() override;

  void Distribute() override;

 private:
  std::shared_ptr<EventWaitTaskInfo> task_info_;
  rtStream_t stream_;
  rtEvent_t event_;
};
}  // namespace mindspore::ge::model_runner
#endif  // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_EVENT_WAIT_TASK_H_
@ -0,0 +1,221 @@
#include "runtime/device/ascend/ge_runtime/task/hccl_task.h"
#include <algorithm>
#include "runtime/device/ascend/ge_runtime/task/task_factory.h"
#include "common/opskernel/ops_kernel_info_store.h"
#include "common/opskernel/ge_task_info.h"

namespace mindspore::ge::model_runner {
std::map<rtModel_t, std::map<uint32_t, std::vector<std::weak_ptr<HcclTask::StreamGuard>>>>
  HcclTask::model_stream_mapping_;
std::mutex HcclTask::model_stream_mapping_mutex_;

HcclTask::HcclTask(const ModelContext &model_context, const std::shared_ptr<HcclTaskInfo> &task_info)
    : TaskRepeater<HcclTaskInfo>(model_context, task_info),
      task_info_(task_info),
      stream_(nullptr),
      workspace_mem_(nullptr),
      rt_model_handle_(nullptr),
      priority_(0),
      secondary_stream_list_() {
  MS_EXCEPTION_IF_NULL(task_info_);

  priority_ = model_context.priority();
  rt_model_handle_ = model_context.rt_model_handle();
  auto stream_list = model_context.stream_list();

  if (stream_list.size() == 1) {
    stream_ = stream_list[0];
  } else if (stream_list.size() > task_info_->stream_id()) {
    stream_ = stream_list[task_info_->stream_id()];
  } else {
    MS_LOG(EXCEPTION) << "Index: " << task_info_->stream_id() << " >= stream_list.size(): " << stream_list.size();
  }
}

HcclTask::~HcclTask() {}

void HcclTask::Distribute() {
  // Get privateDef and opsKernelStorePtr from the ops kernel info store
  MS_LOG(INFO) << "Distribute hccl task start.";
  void *ops_kernel_store = task_info_->ops_kernel_store();
  ::ge::OpsKernelInfoStore *ops_kernel_info_store = reinterpret_cast<::ge::OpsKernelInfoStore *>(ops_kernel_store);
  MS_EXCEPTION_IF_NULL(ops_kernel_info_store);

  char *private_def = reinterpret_cast<char *>(const_cast<unsigned char *>(task_info_->private_def().data()));
  auto private_def_len = static_cast<uint32_t>(task_info_->private_def().size());
  MS_LOG(INFO) << "The first address of the custom info, privateDef = " << static_cast<void *>(private_def);
  SetSecondaryStream();

  if (task_info_->workspace_size() > 0) {
    workspace_mem_ = task_info_->workspace_addr();
  }

  ::ge::GETaskInfo ge_task;
  ge_task.id = 0;
  ge_task.type = static_cast<uint16_t>(RT_MODEL_TASK_HCCL);
  ge_task.stream = stream_;

  ge_task.kernelHcclInfo = std::vector<::ge::GETaskKernelHcclInfo>(1);
  ge_task.kernelHcclInfo[0].hccl_type = task_info_->hccl_type();
  ge_task.kernelHcclInfo[0].inputDataAddr = task_info_->input_data_addr();
  ge_task.kernelHcclInfo[0].outputDataAddr = task_info_->output_data_addr();
  ge_task.kernelHcclInfo[0].workSpaceAddr = workspace_mem_;
  ge_task.kernelHcclInfo[0].workSpaceMemSize = task_info_->workspace_size();
  ge_task.kernelHcclInfo[0].count = task_info_->count();
  ge_task.kernelHcclInfo[0].dataType = static_cast<int32_t>(task_info_->data_type());
  ge_task.kernelHcclInfo[0].opType = static_cast<int32_t>(task_info_->op_type());
  ge_task.kernelHcclInfo[0].rootId = task_info_->root_id();

  std::vector<rtStream_t> secondary_stream_list;
  std::transform(secondary_stream_list_.begin(), secondary_stream_list_.end(),
                 std::back_inserter(secondary_stream_list),
                 [](const std::shared_ptr<StreamGuard> &stream) -> rtStream_t { return stream->GetStream(); });
  ge_task.kernelHcclInfo[0].hcclStreamList = secondary_stream_list;

  ge_task.privateDef = private_def;
  ge_task.privateDefLen = private_def_len;
  ge_task.opsKernelStorePtr = ops_kernel_store;

  MS_LOG(INFO) << "Begin to call function LoadTask in hccl.";
  auto result = ops_kernel_info_store->LoadTask(ge_task);
  // tagHcclResult::HCCL_SUCCESS is 0
  if (result != 0) {
    MS_LOG(EXCEPTION) << "davinci_model: load task failed, return ret: " << result;
  }

  MS_LOG(INFO) << "Call function LoadTask end.";
}

void HcclTask::SetSecondaryStream() {
  const uint32_t master_stream_id = task_info_->stream_id();
  const int64_t hccl_secondary_stream_num = task_info_->hccl_stream_num();
  std::lock_guard<std::mutex> lock(model_stream_mapping_mutex_);

  // No entry for this model yet: create all secondary streams
  auto model_iter = model_stream_mapping_.find(rt_model_handle_);
  if (model_iter == model_stream_mapping_.end()) {
    MS_LOG(INFO) << "Need to create map for rt_model_handle_: " << rt_model_handle_ << " with new mainstream "
                 << master_stream_id;
    CreateStream(hccl_secondary_stream_num, master_stream_id);
    MS_LOG(INFO) << "Initialize hccl secondary stream success, hccl_secondary_stream_num=" << hccl_secondary_stream_num;
    return;
  }

  // Model exists, but this master stream has no secondary streams yet: create all of them
  auto &master_secondary_stream_map = model_iter->second;
  auto iter = master_secondary_stream_map.find(master_stream_id);
  if (iter == master_secondary_stream_map.end()) {
    MS_LOG(INFO) << "Need to create secondary stream for " << task_info_->op_name() << " with new mainstream "
                 << master_stream_id;
    CreateStream(hccl_secondary_stream_num, master_stream_id);
    MS_LOG(INFO) << "Initialize hccl secondary stream success, hccl_secondary_stream_num=" << hccl_secondary_stream_num;
    return;
  }

  // Model and master stream exist, but the cached streams are not enough to reuse: reuse them and create the rest
  std::vector<std::weak_ptr<StreamGuard>> &secondary_stream_vec = iter->second;
  if (static_cast<size_t>(hccl_secondary_stream_num) > secondary_stream_vec.size()) {
    size_t created_stream_num = secondary_stream_vec.size();
    auto need_to_create_num = hccl_secondary_stream_num - created_stream_num;
    MS_LOG(INFO) << "Need to reuse " << secondary_stream_vec.size() << " secondary stream and create "
                 << need_to_create_num << " new secondary stream.";
    for (size_t i = 0; i < secondary_stream_vec.size(); ++i) {
      secondary_stream_list_.push_back(GetSecondaryStream(&secondary_stream_vec, i));
    }
    CreateStream(need_to_create_num, master_stream_id);
    MS_LOG(INFO) << "Initialize hccl secondary stream success, hccl_secondary_stream_num=" << hccl_secondary_stream_num;
    return;
  }

  // All required streams can be reused
  MS_LOG(INFO) << "Number of secondary stream " << hccl_secondary_stream_num << " is enough to be reused.";
  for (int64_t i = 0; i < hccl_secondary_stream_num; ++i) {
    secondary_stream_list_.push_back(GetSecondaryStream(&secondary_stream_vec, i));
  }
  MS_LOG(INFO) << "Initialize hccl secondary stream success, hccl_secondary_stream_num = " << hccl_secondary_stream_num;
}

void HcclTask::CreateStream(int64_t stream_num, int64_t master_stream_id) {
  MS_LOG(INFO) << "Start to create " << stream_num << " hccl secondary stream.";
  for (int64_t i = 0; i < stream_num; ++i) {
    rtStream_t stream = nullptr;
    CreateStream(rt_model_handle_, &stream);
    auto shared_stream = std::make_shared<StreamGuard>(rt_model_handle_, stream);
    SaveHcclSecondaryStream(master_stream_id, shared_stream);
    secondary_stream_list_.push_back(shared_stream);
  }
  MS_LOG(INFO) << "CreateStream success.";
}

void HcclTask::CreateStream(rtModel_t model, rtStream_t *stream) const {
  MS_EXCEPTION_IF_NULL(stream);

  rtError_t rt_ret = rtStreamCreateWithFlags(stream, priority_, RT_STREAM_PERSISTENT | RT_STREAM_FORCE_COPY);
  if (rt_ret != RT_ERROR_NONE) {
    MS_LOG(EXCEPTION) << "Call rt api rtStreamCreateWithFlags failed, ret: " << std::hex << rt_ret;
  }
  // Create secondary stream, inactive by default, activated by hccl
  rt_ret = rtModelBindStream(model, *stream, RT_MODEL_WAIT_ACTIVE_STREAM);
  if (rt_ret != RT_ERROR_NONE) {
    MS_LOG(EXCEPTION) << "Call rt api rtModelBindStream failed, ret: " << std::hex << rt_ret;
  }
}

void HcclTask::SaveHcclSecondaryStream(int64_t master_stream_id, const std::shared_ptr<StreamGuard> &stream) {
  if (model_stream_mapping_.find(rt_model_handle_) == model_stream_mapping_.end()) {
    model_stream_mapping_.emplace(rt_model_handle_, std::map<uint32_t, std::vector<std::weak_ptr<StreamGuard>>>());
  }
  std::map<uint32_t, std::vector<std::weak_ptr<StreamGuard>>> &master_secondary_stream_map =
    model_stream_mapping_.at(rt_model_handle_);
  master_secondary_stream_map[master_stream_id].emplace_back(stream);
}

HcclTask::StreamGuard::~StreamGuard() {
  rtError_t rt_ret = rtModelUnbindStream(model_, stream_);
  if (rt_ret != RT_ERROR_NONE) {
    MS_LOG(ERROR) << "Call rt api rtModelUnbindStream failed, ret: " << std::hex << rt_ret;
    return;
  }

  rt_ret = rtStreamDestroy(stream_);
  if (rt_ret != RT_ERROR_NONE) {
    MS_LOG(ERROR) << "Call rt api rtStreamDestroy failed, ret: " << std::hex << rt_ret;
    return;
  }
}

std::shared_ptr<HcclTask::StreamGuard> HcclTask::GetSecondaryStream(
  std::vector<std::weak_ptr<StreamGuard>> *secondary_streams, size_t index) {
  MS_EXCEPTION_IF_NULL(secondary_streams);
  if (index >= secondary_streams->size()) {
    MS_LOG(EXCEPTION) << "Invalid stream index " << index << ", secondary streams size " << secondary_streams->size();
  }
  auto stream = secondary_streams->at(index).lock();
  if (stream == nullptr) {
    rtStream_t new_stream = nullptr;
    CreateStream(rt_model_handle_, &new_stream);
    stream = std::make_shared<HcclTask::StreamGuard>(rt_model_handle_, new_stream);
    (*secondary_streams)[index] = stream;
  }
  return stream;
}

REGISTER_TASK(TaskInfoType::HCCL, HcclTask, HcclTaskInfo);
}  // namespace mindspore::ge::model_runner
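The model_stream_mapping_ cache above stores secondary streams as weak_ptr entries, so a cached stream stays alive only while at least one HcclTask still owns the corresponding StreamGuard; GetSecondaryStream transparently rebuilds slots whose owners are gone. The underlying get-or-rebuild pattern, reduced to a generic, self-contained sketch:

#include <functional>
#include <memory>
#include <vector>

// Generic sketch of the weak_ptr get-or-rebuild pattern used by GetSecondaryStream.
template <typename T>
std::shared_ptr<T> GetOrRebuild(std::vector<std::weak_ptr<T>> *cache, size_t index,
                                const std::function<std::shared_ptr<T>()> &make) {
  auto strong = cache->at(index).lock();  // nullptr if the last owner released it
  if (strong == nullptr) {
    strong = make();           // rebuild the resource
    (*cache)[index] = strong;  // refresh the cache slot in place
  }
  return strong;
}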
@ -0,0 +1,68 @@
#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_HCCL_TASK_H_
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_HCCL_TASK_H_

#include <memory>
#include <set>
#include <map>
#include <vector>
#include <mutex>
#include "runtime/device/ascend/ge_runtime/task/task.h"

namespace mindspore::ge::model_runner {
class HcclTask : public TaskRepeater<HcclTaskInfo> {
 public:
  HcclTask(const ModelContext &model_context, const std::shared_ptr<HcclTaskInfo> &task_info);

  ~HcclTask() override;

  void Distribute() override;

 private:
  class StreamGuard;
  void SetSecondaryStream();
  void CreateStream(int64_t stream_num, int64_t master_stream_id);
  void CreateStream(rtModel_t model, rtStream_t *stream) const;
  void SaveHcclSecondaryStream(int64_t master_stream_id, const std::shared_ptr<StreamGuard> &stream);
  std::shared_ptr<StreamGuard> GetSecondaryStream(std::vector<std::weak_ptr<StreamGuard>> *secondary_streams,
                                                  size_t index);

  std::shared_ptr<HcclTaskInfo> task_info_;
  void *stream_;
  void *workspace_mem_;
  rtModel_t rt_model_handle_;
  int32_t priority_;
  std::vector<std::shared_ptr<StreamGuard>> secondary_stream_list_;

  // map<key: model pointer, value: map<key: primary stream id, value: vector<secondary stream pointer>>>
  static std::map<rtModel_t, std::map<uint32_t, std::vector<std::weak_ptr<StreamGuard>>>> model_stream_mapping_;
  static std::mutex model_stream_mapping_mutex_;
};

class HcclTask::StreamGuard {
 public:
  StreamGuard(rtModel_t model, rtStream_t stream) : model_(model), stream_(stream) {}
  ~StreamGuard();
  rtStream_t GetStream() const { return stream_; }

 private:
  rtModel_t model_;
  rtStream_t stream_;
};
}  // namespace mindspore::ge::model_runner
#endif  // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_HCCL_TASK_H_
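StreamGuard applies RAII to the create-then-bind pairing in HcclTask::CreateStream: when the last shared_ptr owner goes away, the destructor unbinds the stream from the model and then destroys it. Since StreamGuard is private to HcclTask, here is the same idea as a standalone sketch (assumes the Ascend runtime headers for the rt* types and calls; the class name is illustrative):

#include "runtime/rt_model.h"  // rtModel_t, rtStream_t, rtModelUnbindStream, rtStreamDestroy

// Generic sketch of the RAII idea behind HcclTask::StreamGuard.
class ScopedStream {
 public:
  ScopedStream(rtModel_t model, rtStream_t stream) : model_(model), stream_(stream) {}
  ~ScopedStream() {
    (void)rtModelUnbindStream(model_, stream_);  // undo rtModelBindStream
    (void)rtStreamDestroy(stream_);              // undo rtStreamCreateWithFlags
  }
  rtStream_t get() const { return stream_; }

 private:
  rtModel_t model_;
  rtStream_t stream_;
};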
@ -0,0 +1,83 @@
#include "runtime/device/ascend/ge_runtime/task/label_goto_task.h"
#include "runtime/mem.h"
#include "runtime/device/ascend/ge_runtime/task/task_factory.h"

namespace mindspore::ge::model_runner {
LabelGotoTask::LabelGotoTask(const ModelContext &model_context, const std::shared_ptr<LabelGotoTaskInfo> &task_info)
    : TaskRepeater<LabelGotoTaskInfo>(model_context, task_info),
      task_info_(task_info),
      stream_(nullptr),
      index_value_(nullptr) {
  MS_EXCEPTION_IF_NULL(task_info_);
  auto stream_list = model_context.stream_list();
  auto label_list = model_context.label_list();
  rt_model_handle_ = model_context.rt_model_handle();
  uint32_t stream_id = task_info_->stream_id();
  label_id_ = task_info_->label_id();
  MS_LOG(INFO) << "Stream list size: " << stream_list.size() << ", stream id: " << stream_id;
  MS_LOG(INFO) << "Label list size: " << label_list.size() << ", label id: " << label_id_;
  if (stream_id >= stream_list.size() || label_id_ >= label_list.size()) {
    MS_LOG(EXCEPTION) << "Stream/Label id invalid.";
  }
  stream_ = stream_list[stream_id];
  label_manager_ = LabelManager::GetInstance();
  MS_EXCEPTION_IF_NULL(label_manager_);
  label_info_ = label_manager_->GetLabelInfo(rt_model_handle_, {label_id_}, label_list);
  MS_EXCEPTION_IF_NULL(label_info_);
}

LabelGotoTask::~LabelGotoTask() {
  if (index_value_ != nullptr) {
    rtError_t rt_ret = rtFree(index_value_);
    if (rt_ret != RT_ERROR_NONE) {
      MS_LOG(ERROR) << "Call rtFree index_value_ failed, ret: " << std::hex << rt_ret;
    }
    index_value_ = nullptr;
  }
}

void LabelGotoTask::Distribute() {
  MS_LOG(INFO) << "LabelGotoTask Distribute start.";
  MS_EXCEPTION_IF_NULL(stream_);
  MS_EXCEPTION_IF_NULL(label_info_);

  if (index_value_ == nullptr) {
    rtError_t rt_ret = rtMalloc(&index_value_, sizeof(uint64_t), RT_MEMORY_HBM);
    if (rt_ret != RT_ERROR_NONE) {
      MS_LOG(EXCEPTION) << "Call rt api rtMalloc failed, ret: " << std::hex << rt_ret;
    }

    uint64_t index = 0;
    rt_ret = rtMemcpy(index_value_, sizeof(uint64_t), &index, sizeof(index), RT_MEMCPY_HOST_TO_DEVICE);
    if (rt_ret != RT_ERROR_NONE) {
      MS_LOG(EXCEPTION) << "Call rt api rtMemcpy failed, ret: " << std::hex << rt_ret;
    }
  }

  void *label_info = label_info_->GetLabelInfo();
  rtError_t rt_ret = rtLabelSwitchByIndex(index_value_, 1, label_info, stream_);
  if (rt_ret != RT_ERROR_NONE) {
    MS_LOG(EXCEPTION) << "Call rt api rtLabelSwitchByIndex failed, ret: " << std::hex << rt_ret;
  }

  MS_LOG(INFO) << "DistributeTask end.";
}

REGISTER_TASK(TaskInfoType::LABEL_GOTO, LabelGotoTask, LabelGotoTaskInfo);
}  // namespace mindspore::ge::model_runner
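Note that LabelGotoTask implements an unconditional jump without a dedicated "goto" runtime call: it allocates a device-side uint64_t pinned to 0 and issues rtLabelSwitchByIndex over a one-entry label table, so the "switch" always takes its single branch. The equivalent call in isolation (the wrapper name is hypothetical):

#include "runtime/kernel.h"  // rtLabelSwitchByIndex, rtStream_t, rtError_t

// Sketch: an unconditional branch expressed as a one-entry label switch.
// index_value is device memory holding a uint64_t preset to 0; label_info is a one-label table.
rtError_t GotoLabel(void *index_value, void *label_info, rtStream_t stream) {
  // branch count == 1, and *index_value == 0 always selects entry 0
  return rtLabelSwitchByIndex(index_value, 1, label_info, stream);
}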
@ -0,0 +1,46 @@
#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_

#include <memory>
#include <vector>
#include <map>
#include <mutex>
#include "runtime/device/ascend/ge_runtime/task/task.h"
#include "runtime/device/ascend/ge_runtime/task/label_manager.h"

namespace mindspore::ge::model_runner {
class LabelGotoTask : public TaskRepeater<LabelGotoTaskInfo> {
 public:
  LabelGotoTask(const ModelContext &model_context, const std::shared_ptr<LabelGotoTaskInfo> &task_info);

  ~LabelGotoTask() override;

  void Distribute() override;

 private:
  std::shared_ptr<LabelGotoTaskInfo> task_info_;
  void *stream_;
  std::shared_ptr<LabelGuard> label_info_;
  void *index_value_;
  uint32_t label_id_;
  rtModel_t rt_model_handle_;
  std::shared_ptr<LabelManager> label_manager_;
};
}  // namespace mindspore::ge::model_runner
#endif  // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_
@ -0,0 +1,116 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "runtime/device/ascend/ge_runtime/task/label_manager.h"
#include <algorithm>
#include <string>
#include "runtime/mem.h"
#include "runtime/rt_model.h"
#include "mindspore/core/utils/log_adapter.h"

namespace mindspore::ge::model_runner {
std::weak_ptr<LabelManager> LabelManager::instance_;
std::mutex LabelManager::instance_mutex_;

template <class T>
static std::string GetVectorString(const std::vector<T> &vec) {
  std::string ret;
  for (size_t i = 0; i < vec.size(); ++i) {
    if (i != 0) {
      ret.push_back(',');
    }
    ret += std::to_string(vec[i]);
  }
  return ret;
}

LabelGuard::~LabelGuard() {
  void *label_info = GetLabelInfo();
  if (label_info != nullptr) {
    rtError_t rt_ret = rtFree(label_info);
    if (rt_ret != RT_ERROR_NONE) {
      MS_LOG(ERROR) << "rtFree label_info failed! ret: " << std::hex << rt_ret;
    }
  }
}

std::shared_ptr<LabelManager> LabelManager::GetInstance() {
  std::lock_guard<std::mutex> lock(instance_mutex_);
  auto instance = instance_.lock();
  if (instance != nullptr) {
    return instance;
  }

  instance = std::make_shared<LabelManager>();
  instance_ = instance;
  return instance;
}

std::shared_ptr<LabelGuard> LabelManager::GetLabelInfo(rtModel_t model, const std::vector<uint32_t> &label_ids,
                                                       const std::vector<void *> &all_label) {
  std::lock_guard<std::mutex> lock(model_info_mapping_mutex_);
  rtError_t rt_ret;
  auto model_iter = model_info_mapping_.find(model);
  if (model_iter == model_info_mapping_.end()) {
    model_info_mapping_.emplace(model, std::map<std::string, std::weak_ptr<LabelGuard>>());
    model_iter = model_info_mapping_.find(model);
  }

  std::string label_id_str = GetVectorString(label_ids);
  auto &label_map = model_iter->second;
  auto label_iter = label_map.find(label_id_str);
  if (label_iter != label_map.end()) {
    auto label_guard = label_iter->second.lock();
    if (label_guard != nullptr) {
      MS_LOG(INFO) << "Model " << model << " reuses cached label info for label ids " << label_id_str;
      return label_guard;
    }
  }

  MS_LOG(INFO) << "Alloc label info for label ids " << label_id_str << " of model " << model;
  void *label_info = nullptr;
  std::vector<void *> label_list;
  bool status = true;
  std::transform(label_ids.begin(), label_ids.end(), std::back_inserter(label_list),
                 [&all_label, &status](uint32_t idx) -> void * {
                   if (idx >= all_label.size()) {
                     MS_LOG(ERROR) << "Invalid label id " << idx << ", all label list size " << all_label.size();
                     status = false;
                     return nullptr;
                   }
                   return all_label[idx];
                 });
  if (!status) {
    MS_LOG(ERROR) << "Get label info failed.";
    return nullptr;
  }
  uint32_t label_info_size = sizeof(rtLabelDevInfo) * label_list.size();
  rt_ret = rtMalloc(&label_info, label_info_size, RT_MEMORY_HBM);
  if (rt_ret != RT_ERROR_NONE) {
    MS_LOG(ERROR) << "Call rt api rtMalloc failed, ret: " << std::hex << rt_ret;
    return nullptr;
  }

  rt_ret = rtLabelListCpy(label_list.data(), label_list.size(), label_info, label_info_size);
  if (rt_ret != RT_ERROR_NONE) {
    MS_LOG(ERROR) << "Call rt api rtLabelListCpy failed, ret: " << std::hex << rt_ret;
    (void)rtFree(label_info);  // avoid leaking the device buffer on the error path
    return nullptr;
  }

  auto label_guard = std::make_shared<LabelGuard>(label_info);
  label_map.emplace(label_id_str, label_guard);
  return label_guard;
}
}  // namespace mindspore::ge::model_runner
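LabelManager::GetInstance is a weak_ptr-backed singleton: the manager, and with it the cached device label tables, is freed once the last task drops its shared_ptr, then lazily rebuilt on the next request. The pattern in isolation, as a compilable sketch (the class name is illustrative):

#include <memory>
#include <mutex>

// Minimal sketch of the weak_ptr-backed singleton used by LabelManager.
class Registry {
 public:
  static std::shared_ptr<Registry> GetInstance() {
    std::lock_guard<std::mutex> lock(mutex_);
    auto instance = instance_.lock();
    if (instance == nullptr) {
      instance = std::make_shared<Registry>();  // rebuilt after the last owner released it
      instance_ = instance;
    }
    return instance;
  }

 private:
  static std::weak_ptr<Registry> instance_;
  static std::mutex mutex_;
};

std::weak_ptr<Registry> Registry::instance_;
std::mutex Registry::mutex_;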
@ -0,0 +1,51 @@
#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_LABEL_MANAGER_H_
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_LABEL_MANAGER_H_

#include <vector>
#include <memory>
#include <mutex>
#include <map>
#include <string>
#include "runtime/base.h"

namespace mindspore::ge::model_runner {
class LabelGuard {
 public:
  explicit LabelGuard(void *label_info) : label_info_(reinterpret_cast<uintptr_t>(label_info)) {}
  ~LabelGuard();
  void *GetLabelInfo() { return reinterpret_cast<void *>(label_info_); }

 private:
  uintptr_t label_info_;
};

class LabelManager {
 public:
  static std::shared_ptr<LabelManager> GetInstance();
  std::shared_ptr<LabelGuard> GetLabelInfo(rtModel_t model, const std::vector<uint32_t> &label_ids,
                                           const std::vector<void *> &all_label);

 private:
  std::mutex model_info_mapping_mutex_;
  std::map<rtModel_t, std::map<std::string, std::weak_ptr<LabelGuard>>> model_info_mapping_;

  static std::weak_ptr<LabelManager> instance_;
  static std::mutex instance_mutex_;
};
}  // namespace mindspore::ge::model_runner
#endif  // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_LABEL_MANAGER_H_
@ -0,0 +1,56 @@
#include "runtime/device/ascend/ge_runtime/task/label_set_task.h"
#include "runtime/device/ascend/ge_runtime/task/task_factory.h"

namespace mindspore::ge::model_runner {
LabelSetTask::LabelSetTask(const ModelContext &model_context, const std::shared_ptr<LabelSetTaskInfo> &task_info)
    : TaskRepeater<LabelSetTaskInfo>(model_context, task_info),
      task_info_(task_info),
      stream_(nullptr),
      label_(nullptr) {
  MS_EXCEPTION_IF_NULL(task_info_);
  auto stream_list = model_context.stream_list();
  auto label_list = model_context.label_list();
  uint32_t stream_id = task_info->stream_id();
  uint32_t label_id = task_info->label_id();
  MS_LOG(INFO) << "Stream list size: " << stream_list.size() << ", stream id: " << stream_id;
  MS_LOG(INFO) << "Label list size: " << label_list.size() << ", label id: " << label_id;
  if (stream_id >= stream_list.size() || label_id >= label_list.size()) {
    MS_LOG(EXCEPTION) << "Stream/Label id invalid.";
  }
  stream_ = stream_list[stream_id];
  label_ = label_list[label_id];
}

LabelSetTask::~LabelSetTask() {}

void LabelSetTask::Distribute() {
  MS_LOG(INFO) << "LabelSetTask Distribute start.";
  MS_EXCEPTION_IF_NULL(stream_);
  MS_EXCEPTION_IF_NULL(label_);

  rtError_t rt_ret = rtLabelSet(label_, stream_);
  if (rt_ret != RT_ERROR_NONE) {
    MS_LOG(EXCEPTION) << "Call rt api rtLabelSet failed, ret: " << std::hex << rt_ret;
  }

  MS_LOG(INFO) << "DistributeTask end.";
}

REGISTER_TASK(TaskInfoType::LABEL_SET, LabelSetTask, LabelSetTaskInfo);
}  // namespace mindspore::ge::model_runner
@ -0,0 +1,38 @@
#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_LABEL_SET_TASK_H_
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_LABEL_SET_TASK_H_

#include <memory>
#include "runtime/device/ascend/ge_runtime/task/task.h"

namespace mindspore::ge::model_runner {
class LabelSetTask : public TaskRepeater<LabelSetTaskInfo> {
 public:
  LabelSetTask(const ModelContext &model_context, const std::shared_ptr<LabelSetTaskInfo> &task_info);

  ~LabelSetTask() override;

  void Distribute() override;

 private:
  std::shared_ptr<LabelSetTaskInfo> task_info_;
  void *stream_;
  void *label_;
};
}  // namespace mindspore::ge::model_runner
#endif  // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_LABEL_SET_TASK_H_
@ -0,0 +1,77 @@
#include "runtime/device/ascend/ge_runtime/task/label_switch_task.h"
#include "runtime/device/ascend/ge_runtime/task/task_factory.h"

namespace mindspore::ge::model_runner {
LabelSwitchTask::LabelSwitchTask(const ModelContext &model_context,
                                 const std::shared_ptr<LabelSwitchTaskInfo> &task_info)
    : TaskRepeater<LabelSwitchTaskInfo>(model_context, task_info),
      task_info_(task_info),
      stream_(nullptr),
      label_info_(nullptr) {
  MS_EXCEPTION_IF_NULL(task_info);

  rt_model_handle_ = model_context.rt_model_handle();
  auto all_label_resource = model_context.label_list();
  auto stream_list = model_context.stream_list();
  uint32_t stream_id = task_info->stream_id();
  MS_LOG(INFO) << "Stream list size: " << stream_list.size() << ", stream id: " << stream_id;
  if (stream_id >= stream_list.size()) {
    MS_LOG(EXCEPTION) << "Stream id invalid.";
  }
  stream_ = stream_list[stream_id];
  label_manager_ = LabelManager::GetInstance();
  MS_EXCEPTION_IF_NULL(label_manager_);
  label_info_ = label_manager_->GetLabelInfo(rt_model_handle_, task_info_->label_list(), all_label_resource);
  MS_EXCEPTION_IF_NULL(label_info_);
}

LabelSwitchTask::~LabelSwitchTask() {}

void LabelSwitchTask::Distribute() {
  MS_LOG(INFO) << "LabelSwitchTask Distribute start.";
  CheckParamValid();

  void *label_info = label_info_->GetLabelInfo();
  rtError_t rt_ret = rtLabelSwitchByIndex(task_info_->cond(), task_info_->label_size(), label_info, stream_);
  if (rt_ret != RT_ERROR_NONE) {
    MS_LOG(EXCEPTION) << "Call rt api rtLabelSwitchByIndex failed, ret: " << std::hex << rt_ret;
  }

  MS_LOG(INFO) << "DistributeTask end.";
}

void LabelSwitchTask::CheckParamValid() {
  MS_EXCEPTION_IF_NULL(stream_);

  if (task_info_->label_list().empty()) {
    MS_LOG(EXCEPTION) << "label_list is empty.";
  }

  if (task_info_->label_size() != task_info_->label_list().size()) {
    MS_LOG(EXCEPTION) << "label_list size is " << task_info_->label_list().size() << " but label_size is "
                      << task_info_->label_size();
  }

  if (task_info_->label_size() >= UINT32_MAX / sizeof(rtLabelDevInfo)) {
    MS_LOG(EXCEPTION) << "label_size " << task_info_->label_size() << " will overflow.";
  }
}

REGISTER_TASK(TaskInfoType::LABEL_SWITCH, LabelSwitchTask, LabelSwitchTaskInfo);
}  // namespace mindspore::ge::model_runner
@ -0,0 +1,43 @@
#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_

#include <memory>
#include "runtime/device/ascend/ge_runtime/task/task.h"
#include "runtime/device/ascend/ge_runtime/task/label_manager.h"

namespace mindspore::ge::model_runner {
class LabelSwitchTask : public TaskRepeater<LabelSwitchTaskInfo> {
 public:
  LabelSwitchTask(const ModelContext &model_context, const std::shared_ptr<LabelSwitchTaskInfo> &task_info);

  ~LabelSwitchTask() override;

  void Distribute() override;

 private:
  void CheckParamValid();

  std::shared_ptr<LabelSwitchTaskInfo> task_info_;
  void *stream_;
  rtModel_t rt_model_handle_;
  std::shared_ptr<LabelGuard> label_info_;
  std::shared_ptr<LabelManager> label_manager_;
};
}  // namespace mindspore::ge::model_runner
#endif  // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_
@ -0,0 +1,51 @@
#include "runtime/device/ascend/ge_runtime/task/memcpy_async_task.h"
#include "runtime/mem.h"
#include "runtime/device/ascend/ge_runtime/task/task_factory.h"

namespace mindspore::ge::model_runner {
MemcpyAsyncTask::MemcpyAsyncTask(const ModelContext &model_context,
                                 const std::shared_ptr<MemcpyAsyncTaskInfo> &task_info)
    : TaskRepeater<MemcpyAsyncTaskInfo>(model_context, task_info), task_info_(task_info), stream_(nullptr) {
  MS_EXCEPTION_IF_NULL(task_info);
  auto stream_list = model_context.stream_list();
  uint32_t stream_id = task_info->stream_id();

  MS_LOG(INFO) << "Stream list size: " << stream_list.size() << ", stream id: " << stream_id;
  if (stream_id >= stream_list.size()) {
    MS_LOG(EXCEPTION) << "Index: " << task_info->stream_id() << " >= stream_list.size(): " << stream_list.size();
  }
  stream_ = stream_list[stream_id];
}

MemcpyAsyncTask::~MemcpyAsyncTask() {}

void MemcpyAsyncTask::Distribute() {
  MS_LOG(INFO) << "MemcpyAsyncTask Distribute start.";
  MS_LOG(INFO) << "dst_max: " << task_info_->dst_max() << ", count: " << task_info_->count()
               << ", kind: " << task_info_->kind();
  rtError_t rt_ret = rtMemcpyAsync(task_info_->dst(), task_info_->dst_max(), task_info_->src(), task_info_->count(),
                                   static_cast<rtMemcpyKind_t>(task_info_->kind()), stream_);
  if (rt_ret != RT_ERROR_NONE) {
    MS_LOG(EXCEPTION) << "Call rt api rtMemcpyAsync failed, ret: " << std::hex << rt_ret;
  }
  MS_LOG(INFO) << "DistributeTask end.";
}

REGISTER_TASK(TaskInfoType::MEMCPY_ASYNC, MemcpyAsyncTask, MemcpyAsyncTaskInfo);
}  // namespace mindspore::ge::model_runner
@ -0,0 +1,37 @@
#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_MEMCPY_ASYNC_TASK_H_
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_MEMCPY_ASYNC_TASK_H_

#include <memory>
#include "runtime/device/ascend/ge_runtime/task/task.h"

namespace mindspore::ge::model_runner {
class MemcpyAsyncTask : public TaskRepeater<MemcpyAsyncTaskInfo> {
 public:
  MemcpyAsyncTask(const ModelContext &model_context, const std::shared_ptr<MemcpyAsyncTaskInfo> &task_info);

  ~MemcpyAsyncTask() override;

  void Distribute() override;

 private:
  std::shared_ptr<MemcpyAsyncTaskInfo> task_info_;
  rtStream_t stream_;
};
}  // namespace mindspore::ge::model_runner
#endif  // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_MEMCPY_ASYNC_TASK_H_
@ -0,0 +1,47 @@
#include "runtime/device/ascend/ge_runtime/task/profiler_task.h"
#include "runtime/device/ascend/ge_runtime/task/task_factory.h"

namespace mindspore::ge::model_runner {
ProfilerTask::ProfilerTask(const ModelContext &model_context, const std::shared_ptr<ProfilerTraceTaskInfo> &task_info)
    : TaskRepeater<ProfilerTraceTaskInfo>(model_context, task_info), task_info_(task_info), stream_(nullptr) {
  MS_EXCEPTION_IF_NULL(task_info);
  auto stream_list = model_context.stream_list();
  uint32_t stream_id = task_info->stream_id();
  MS_LOG(INFO) << "Stream list size: " << stream_list.size() << ", stream id: " << stream_id;
  if (stream_id >= stream_list.size()) {
    MS_LOG(EXCEPTION) << "Index: " << task_info->stream_id() << " >= stream_list.size(): " << stream_list.size();
  }
  stream_ = stream_list[stream_id];
}

ProfilerTask::~ProfilerTask() {}

void ProfilerTask::Distribute() {
  MS_LOG(INFO) << "ProfilerTask Distribute start.";
  MS_LOG(INFO) << "log id = " << task_info_->log_id() << ", notify = " << task_info_->notify()
               << ", flat = " << task_info_->flat();
  rtError_t rt_ret = rtProfilerTrace(task_info_->log_id(), task_info_->notify(), task_info_->flat(), stream_);
  if (rt_ret != RT_ERROR_NONE) {
    MS_LOG(EXCEPTION) << "Call rt api rtProfilerTrace failed, ret: " << std::hex << rt_ret;
  }
  MS_LOG(INFO) << "DistributeTask end.";
}

REGISTER_TASK(TaskInfoType::PROFILER_TRACE, ProfilerTask, ProfilerTraceTaskInfo);
}  // namespace mindspore::ge::model_runner
@ -0,0 +1,37 @@
#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_PROFILER_TASK_H_
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_PROFILER_TASK_H_

#include <memory>
#include "runtime/device/ascend/ge_runtime/task/task.h"

namespace mindspore::ge::model_runner {
class ProfilerTask : public TaskRepeater<ProfilerTraceTaskInfo> {
 public:
  ProfilerTask(const ModelContext &model_context, const std::shared_ptr<ProfilerTraceTaskInfo> &task_info);

  ~ProfilerTask() override;

  void Distribute() override;

 private:
  std::shared_ptr<ProfilerTraceTaskInfo> task_info_;
  rtStream_t stream_;
};
}  // namespace mindspore::ge::model_runner
#endif  // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_PROFILER_TASK_H_
@ -0,0 +1,56 @@
#include "runtime/device/ascend/ge_runtime/task/stream_active_task.h"
#include "runtime/kernel.h"
#include "runtime/device/ascend/ge_runtime/task/task_factory.h"

namespace mindspore::ge::model_runner {
StreamActiveTask::StreamActiveTask(const ModelContext &model_context,
                                   const std::shared_ptr<StreamActiveTaskInfo> &task_info)
    : TaskRepeater<StreamActiveTaskInfo>(model_context, task_info),
      task_info_(task_info),
      stream_(nullptr),
      active_stream_(nullptr) {
  MS_EXCEPTION_IF_NULL(task_info);
  auto stream_list = model_context.stream_list();
  uint32_t stream_id = task_info->stream_id();
  uint32_t active_stream_id = task_info->active_stream_id();
  MS_LOG(INFO) << "Stream list size: " << stream_list.size() << ", stream id: " << stream_id
               << ", active stream id: " << active_stream_id;
  if (stream_id >= stream_list.size() || active_stream_id >= stream_list.size()) {
    MS_LOG(EXCEPTION) << "Stream id invalid.";
  }
  stream_ = stream_list[stream_id];
  active_stream_ = stream_list[active_stream_id];
}

StreamActiveTask::~StreamActiveTask() {}

void StreamActiveTask::Distribute() {
  MS_LOG(INFO) << "StreamActiveTask Distribute start.";
  MS_LOG(INFO) << "Stream " << task_info_->stream_id() << " activates stream " << task_info_->active_stream_id();
  MS_EXCEPTION_IF_NULL(stream_);
  MS_EXCEPTION_IF_NULL(active_stream_);
  rtError_t rt_ret = rtStreamActive(active_stream_, stream_);
  if (rt_ret != RT_ERROR_NONE) {
    MS_LOG(EXCEPTION) << "Call rt api rtStreamActive failed, ret: " << std::hex << rt_ret;
  }
  MS_LOG(INFO) << "DistributeTask end.";
}

REGISTER_TASK(TaskInfoType::STREAM_ACTIVE, StreamActiveTask, StreamActiveTaskInfo);
}  // namespace mindspore::ge::model_runner
@ -0,0 +1,38 @@
#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_STREAM_ACTIVE_TASK_H_
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_STREAM_ACTIVE_TASK_H_

#include <memory>
#include "runtime/device/ascend/ge_runtime/task/task.h"

namespace mindspore::ge::model_runner {
class StreamActiveTask : public TaskRepeater<StreamActiveTaskInfo> {
 public:
  StreamActiveTask(const ModelContext &model_context, const std::shared_ptr<StreamActiveTaskInfo> &task_info);

  ~StreamActiveTask() override;

  void Distribute() override;

 private:
  std::shared_ptr<StreamActiveTaskInfo> task_info_;
  rtStream_t stream_;
  rtStream_t active_stream_;
};
}  // namespace mindspore::ge::model_runner
#endif  // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_STREAM_ACTIVE_TASK_H_
@ -0,0 +1,70 @@
/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "runtime/device/ascend/ge_runtime/task/stream_switch_task.h"
#include "runtime/kernel.h"
#include "runtime/device/ascend/ge_runtime/task/task_factory.h"

namespace mindspore::ge::model_runner {
StreamSwitchTask::StreamSwitchTask(const ModelContext &model_context,
                                   const std::shared_ptr<StreamSwitchTaskInfo> &task_info)
    : TaskRepeater<StreamSwitchTaskInfo>(model_context, task_info),
      task_info_(task_info),
      stream_(nullptr),
      stream_list_() {
  MS_EXCEPTION_IF_NULL(task_info);
  stream_list_ = model_context.stream_list();
  if (stream_list_.size() == 1) {
    stream_ = stream_list_[0];
  } else if (stream_list_.size() > task_info->stream_id()) {
    stream_ = stream_list_[task_info->stream_id()];
  } else {
    MS_LOG(EXCEPTION) << "Index: " << task_info->stream_id() << " >= stream_list.size(): " << stream_list_.size();
  }
}

StreamSwitchTask::~StreamSwitchTask() {}

void StreamSwitchTask::Distribute() {
  MS_LOG(INFO) << "Init StreamSwitchTask start.";
  MS_LOG(INFO) << "Stream " << task_info_->stream_id() << " active " << task_info_->true_stream_id();
  MS_EXCEPTION_IF_NULL(stream_);

  if (static_cast<uint64_t>(task_info_->true_stream_id()) >= stream_list_.size()) {
    MS_LOG(EXCEPTION) << "true_stream_id " << task_info_->true_stream_id() << " must be less than stream_list_ size "
                      << stream_list_.size();
  }

  void *input = reinterpret_cast<void *>(task_info_->input_addr());
  rtCondition_t cond = static_cast<rtCondition_t>(task_info_->cond());
  void *value = reinterpret_cast<void *>(task_info_->value_addr());
  rtStream_t true_stream = stream_list_[task_info_->true_stream_id()];
  rtSwitchDataType_t data_type = static_cast<rtSwitchDataType_t>(task_info_->data_type());

  MS_LOG(INFO) << "InitStreamSwitchTask, cond: " << cond << ", trueStream: " << true_stream
               << ", trueStreamID: " << task_info_->true_stream_id() << ", datatype: " << task_info_->data_type();

  MS_LOG(INFO) << "StreamSwitchTask Distribute Start.";
  rtError_t rt_ret = rtStreamSwitchEx(input, cond, value, true_stream, stream_, data_type);
  if (rt_ret != RT_ERROR_NONE) {
    MS_LOG(EXCEPTION) << "Call rt api rtStreamSwitchEx failed, ret: " << std::hex << rt_ret;
  }

  MS_LOG(INFO) << "Distribute StreamSwitch success";
}

REGISTER_TASK(TaskInfoType::STREAM_SWITCH, StreamSwitchTask, StreamSwitchTaskInfo);
}  // namespace mindspore::ge::model_runner
@ -0,0 +1,40 @@
/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_STREAM_SWITCH_TASK_H_
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_STREAM_SWITCH_TASK_H_

#include <memory>
#include <vector>
#include "runtime/device/ascend/ge_runtime/task/task.h"

namespace mindspore::ge::model_runner {
class StreamSwitchTask : public TaskRepeater<StreamSwitchTaskInfo> {
 public:
  StreamSwitchTask(const ModelContext &model_context, const std::shared_ptr<StreamSwitchTaskInfo> &task_info);

  ~StreamSwitchTask() override;

  void Distribute() override;

 private:
  std::shared_ptr<StreamSwitchTaskInfo> task_info_;

  void *stream_;
  std::vector<rtStream_t> stream_list_;
};
}  // namespace mindspore::ge::model_runner
#endif  // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_STREAM_SWITCH_TASK_H_
@ -0,0 +1,53 @@
/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_TASK_H_
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_TASK_H_

#include <memory>
#include <type_traits>
#include <utility>
#include <vector>
#include <string>
#include "runtime/device/ascend/ge_runtime/model_context.h"
#include "runtime/device/ascend/ge_runtime/task_info.h"

namespace mindspore::ge::model_runner {
class Task {
 public:
  Task() {}

  virtual ~Task() {}

  virtual void Distribute() = 0;

  virtual void *Args() { return nullptr; }

  virtual std::string task_name() const { return ""; }
};
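
// TaskRepeater pins the concrete TaskInfo type a task consumes; std::is_base_of rejects any T not derived from
// TaskInfo at compile time.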
template <class T>
class TaskRepeater : public Task {
  static_assert(std::is_base_of<TaskInfo, T>(), "Wrong TaskInfo Type!");

 public:
  TaskRepeater(const ModelContext &model_context, const std::shared_ptr<T> &task_info) {}

  virtual ~TaskRepeater() {}

  virtual void Distribute() = 0;
};
}  // namespace mindspore::ge::model_runner
#endif  // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_TASK_H_
@ -0,0 +1,84 @@
/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_TASK_FACTORY_H_
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_TASK_FACTORY_H_

#include <functional>
#include <map>
#include <memory>
#include <unordered_map>
#include "runtime/device/ascend/ge_runtime/task_info.h"
#include "mindspore/core/utils/log_adapter.h"

namespace mindspore::ge::model_runner {
class Task;
class ModelContext;
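// A creator builds a concrete Task from the shared model context and a type-erased TaskInfo.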
using TASK_CREATOR_FUN = std::function<std::shared_ptr<Task>(const ModelContext &, std::shared_ptr<TaskInfo>)>;

class TaskFactory {
 private:
  TaskFactory() {}
  ~TaskFactory() {}
  void RegisterCreator(const TaskInfoType &type, const TASK_CREATOR_FUN &func) {
    if (creator_map_.find(type) != creator_map_.end()) {
      MS_LOG(WARNING) << "Creator type " << type << " already exists.";
    }
    creator_map_[type] = func;
  }

  std::map<TaskInfoType, TASK_CREATOR_FUN> creator_map_;

 public:
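  // Meyers singleton: the function-local static is initialized exactly once (thread-safe since C++11).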
  static TaskFactory &GetInstance() {
    static TaskFactory instance;
    return instance;
  }

  std::shared_ptr<Task> Create(const ModelContext &model_context, const std::shared_ptr<TaskInfo> &task_info) const {
    if (task_info == nullptr) {
      MS_LOG(ERROR) << "task_info is null.";
      return nullptr;
    }

    auto iter = creator_map_.find(task_info->type());
    if (iter == creator_map_.end()) {
      MS_LOG(ERROR) << "Unknown task type " << task_info->type();
      return nullptr;
    }
    return iter->second(model_context, task_info);
  }

  class Register {
   public:
    Register(const TaskInfoType &type, const TASK_CREATOR_FUN &func) {
      MS_LOG(DEBUG) << "register type " << type;
      TaskFactory::GetInstance().RegisterCreator(type, func);
    }

    ~Register() {}
  };
};
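
// REGISTER_TASK wires a concrete Task into the factory at static-initialization time: the Register object installs a
// creator lambda that downcasts the generic TaskInfo to task_info_clazz and constructs task_clazz from it.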
#define REGISTER_TASK(type, task_clazz, task_info_clazz)                                                               \
  TaskFactory::Register g_##task_clazz##_register(                                                                     \
    type, [](const ModelContext &model_context, const std::shared_ptr<TaskInfo> &task_info) -> std::shared_ptr<Task> { \
      std::shared_ptr<task_info_clazz> concrete_task_info = std::static_pointer_cast<task_info_clazz>(task_info);      \
      return std::make_shared<task_clazz>(model_context, concrete_task_info);                                          \
    });

}  // namespace mindspore::ge::model_runner
#endif  // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_TASK_FACTORY_H_
@ -0,0 +1,97 @@
/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "runtime/device/ascend/ge_runtime/task/tbe_task.h"
#include <vector>
#include "runtime/mem.h"
#include "runtime/kernel.h"
#include "runtime/device/ascend/ge_runtime/task/task_factory.h"

namespace mindspore::ge::model_runner {
TbeTask::TbeTask(const ModelContext &model_context, const std::shared_ptr<TbeTaskInfo> &task_info)
    : TaskRepeater<TbeTaskInfo>(model_context, task_info),
      task_info_(task_info),
      stream_(nullptr),
      stub_func_(nullptr),
      args_(nullptr) {
  MS_EXCEPTION_IF_NULL(task_info);

  auto stream_list = model_context.stream_list();
  if (stream_list.size() == 1) {
    stream_ = stream_list[0];
  } else if (stream_list.size() > task_info->stream_id()) {
    stream_ = stream_list[task_info->stream_id()];
  } else {
    MS_LOG(EXCEPTION) << "Index: " << task_info->stream_id() << " >= stream_list.size(): " << stream_list.size();
  }
}

TbeTask::~TbeTask() {
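  // Free the device-side args buffer allocated in Distribute().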
  if (args_ != nullptr) {
    rtError_t rt_ret = rtFree(args_);
    if (rt_ret != RT_ERROR_NONE) {
      MS_LOG(ERROR) << "Call rt api rtFree failed, ret: " << std::hex << rt_ret;
    }
    args_ = nullptr;
  }
}

void TbeTask::Distribute() {
  MS_LOG(INFO) << "InitTbeTask start.";
  MS_EXCEPTION_IF_NULL(stream_);
  // Get stub_func
  if (task_info_->stub_func().empty()) {
    MS_LOG(EXCEPTION) << "kernel_info->stub_func is empty!";
  }

  rtError_t rt_ret = rtGetFunctionByName(const_cast<char *>(task_info_->stub_func().c_str()), &stub_func_);
  if (rt_ret != RT_ERROR_NONE) {
    MS_LOG(EXCEPTION) << "Call rt api rtGetFunctionByName failed, ret: " << std::hex << rt_ret;
  }
  MS_LOG(INFO) << "TbeTask: stub_func = " << task_info_->stub_func();

  // Get args
  std::vector<void *> tensor_device_addrs;
  tensor_device_addrs.insert(tensor_device_addrs.end(), task_info_->input_data_addrs().begin(),
                             task_info_->input_data_addrs().end());
  tensor_device_addrs.insert(tensor_device_addrs.end(), task_info_->output_data_addrs().begin(),
                             task_info_->output_data_addrs().end());
  tensor_device_addrs.insert(tensor_device_addrs.end(), task_info_->workspace_addrs().begin(),
                             task_info_->workspace_addrs().end());
  auto args_size = static_cast<uint32_t>(tensor_device_addrs.size() * sizeof(void *));
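
  // Stage the kernel arguments: copy the gathered input/output/workspace pointers into an HBM buffer that the
  // kernel receives as its argument block.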
  rt_ret = rtMalloc(&args_, args_size, RT_MEMORY_HBM);
  if (rt_ret != RT_ERROR_NONE) {
    MS_LOG(EXCEPTION) << "Call rt api rtMalloc failed, ret: " << std::hex << rt_ret << " mem size " << args_size;
  }

  rt_ret = rtMemcpy(args_, args_size, reinterpret_cast<void *>(tensor_device_addrs.data()), args_size,
                    RT_MEMCPY_HOST_TO_DEVICE);
  if (rt_ret != RT_ERROR_NONE) {
    MS_LOG(EXCEPTION) << "Call rt api rtMemcpy failed, ret: " << std::hex << rt_ret;
  }

  MS_LOG(INFO) << "DistributeTbeTask start.";
  auto dump_flag = task_info_->dump_flag() ? RT_KERNEL_DUMPFLAG : RT_KERNEL_DEFAULT;
  rt_ret = rtKernelLaunchWithFlag(stub_func_, task_info_->block_dim(), args_, args_size, nullptr, stream_, dump_flag);
  if (rt_ret != RT_ERROR_NONE) {
    MS_LOG(EXCEPTION) << "Call rt api rtKernelLaunch failed, ret: " << std::hex << rt_ret << " mem size " << args_size;
  }
  MS_LOG(INFO) << "[DataDump] task name: " << task_info_->op_name() << " dump_flag: " << dump_flag;
}

REGISTER_TASK(TaskInfoType::TBE, TbeTask, TbeTaskInfo);
}  // namespace mindspore::ge::model_runner
@ -0,0 +1,44 @@
/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_TBE_TASK_H_
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_TBE_TASK_H_

#include <string>
#include <memory>
#include "runtime/device/ascend/ge_runtime/task/task.h"

namespace mindspore::ge::model_runner {
class TbeTask : public TaskRepeater<TbeTaskInfo> {
 public:
  TbeTask(const ModelContext &model_context, const std::shared_ptr<TbeTaskInfo> &task_info);

  ~TbeTask() override;

  void Distribute() override;

  void *Args() override { return args_; }

  std::string task_name() const override { return task_info_->op_name(); }

 private:
  std::shared_ptr<TbeTaskInfo> task_info_;
  void *stream_;
  void *stub_func_;
  void *args_;
};
}  // namespace mindspore::ge::model_runner
#endif  // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_TBE_TASK_H_
@ -0,0 +1,364 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_INFO_H_
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_INFO_H_

#include <stdint.h>
#include <memory>
#include <string>
#include <utility>
#include <vector>

namespace mindspore::ge::model_runner {
enum TaskInfoType {
  CCE = 0,
  TBE,
  AICPU,
  LABEL_SET,
  LABEL_SWITCH,
  LABEL_GOTO,
  EVENT_RECORD,
  EVENT_WAIT,
  FUSION_START,
  FUSION_END,
  HCCL,
  PROFILER_TRACE,
  MEMCPY_ASYNC,
  STREAM_SWITCH,
  STREAM_ACTIVE,
  // Insert new task type here
  RESERVED = 23
};
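
// Base description of a runtime task: op name, stream id, task type and dump flag. Concrete subclasses carry the
// per-type payload that the matching Task consumes at Distribute() time.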
class TaskInfo {
 public:
  virtual ~TaskInfo() {}
  uint32_t stream_id() const { return stream_id_; }
  TaskInfoType type() const { return type_; }
  std::string op_name() const { return op_name_; }
  bool dump_flag() const { return dump_flag_; }

 protected:
  TaskInfo(const std::string &op_name, uint32_t stream_id, TaskInfoType type, bool dump_flag)
      : op_name_(op_name), stream_id_(stream_id), type_(type), dump_flag_(dump_flag) {}

 private:
  std::string op_name_;
  uint32_t stream_id_;
  TaskInfoType type_;
  bool dump_flag_;
};

class TbeTaskInfo : public TaskInfo {
 public:
  TbeTaskInfo(const std::string &op_name, uint32_t stream_id, const std::string &stub_func, uint32_t block_dim,
              const std::vector<uint8_t> &args, uint32_t args_size, const std::vector<uint8_t> &sm_desc, void *binary,
              uint32_t binary_size, const std::vector<uint8_t> &meta_data, const std::vector<void *> &input_data_addrs,
              const std::vector<void *> &output_data_addrs, const std::vector<void *> &workspace_addrs, bool dump_flag)
      : TaskInfo(op_name, stream_id, TaskInfoType::TBE, dump_flag),
        stub_func_(stub_func),
        block_dim_(block_dim),
        args_(args),
        args_size_(args_size),
        sm_desc_(sm_desc),
        binary_(binary),
        binary_size_(binary_size),
        meta_data_(meta_data),
        input_data_addrs_(input_data_addrs),
        output_data_addrs_(output_data_addrs),
        workspace_addrs_(workspace_addrs) {}
  ~TbeTaskInfo() override {}

  const std::string &stub_func() const { return stub_func_; }
  uint32_t block_dim() const { return block_dim_; }
  const std::vector<uint8_t> &args() const { return args_; }
  uint32_t args_size() const { return args_size_; }
  const std::vector<uint8_t> &sm_desc() const { return sm_desc_; }
  void *binary() const { return binary_; }
  uint32_t binary_size() const { return binary_size_; }
  const std::vector<uint8_t> &meta_data() const { return meta_data_; }
  const std::vector<void *> &input_data_addrs() const { return input_data_addrs_; }
  const std::vector<void *> &output_data_addrs() const { return output_data_addrs_; }
  const std::vector<void *> &workspace_addrs() const { return workspace_addrs_; }

  void SetBinary(void *binary, uint32_t binary_size) {
    binary_ = binary;
    binary_size_ = binary_size;
  }

 private:
  std::string stub_func_;
  uint32_t block_dim_;
  std::vector<uint8_t> args_;
  uint32_t args_size_;
  std::vector<uint8_t> sm_desc_;
  void *binary_;
  uint32_t binary_size_;
  std::vector<uint8_t> meta_data_;
  std::vector<void *> input_data_addrs_;
  std::vector<void *> output_data_addrs_;
  std::vector<void *> workspace_addrs_;
};

class AicpuTaskInfo : public TaskInfo {
 public:
  AicpuTaskInfo(const std::string &op_name, uint32_t stream_id, const std::string &so_name,
                const std::string &kernel_name, const std::string &node_def, const std::string &ext_info,
                const std::vector<void *> &input_data_addrs, const std::vector<void *> &output_data_addrs,
                bool dump_flag)
      : TaskInfo(op_name, stream_id, TaskInfoType::AICPU, dump_flag),
        so_name_(so_name),
        kernel_name_(kernel_name),
        node_def_(node_def),
        ext_info_(ext_info),
        input_data_addrs_(input_data_addrs),
        output_data_addrs_(output_data_addrs) {}
  ~AicpuTaskInfo() override {}

  const std::string &so_name() const { return so_name_; }
  const std::string &kernel_name() const { return kernel_name_; }
  const std::string &node_def() const { return node_def_; }
  const std::vector<void *> &input_data_addrs() const { return input_data_addrs_; }
  const std::vector<void *> &output_data_addrs() const { return output_data_addrs_; }
  const std::string &ext_info() const { return ext_info_; }

 private:
  std::string so_name_;
  std::string kernel_name_;
  std::string node_def_;
  std::string ext_info_;
  std::vector<void *> input_data_addrs_;
  std::vector<void *> output_data_addrs_;
};

class LabelSetTaskInfo : public TaskInfo {
 public:
  LabelSetTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_id)
      : TaskInfo(op_name, stream_id, TaskInfoType::LABEL_SET, false), label_id_(label_id) {}
  ~LabelSetTaskInfo() override {}
  uint32_t label_id() const { return label_id_; }

 private:
  uint32_t label_id_;
};

class LabelGotoTaskInfo : public TaskInfo {
 public:
  LabelGotoTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_id)
      : TaskInfo(op_name, stream_id, TaskInfoType::LABEL_GOTO, false), label_id_(label_id) {}
  ~LabelGotoTaskInfo() override {}
  uint32_t label_id() const { return label_id_; }

 private:
  uint32_t label_id_;
};

class LabelSwitchTaskInfo : public TaskInfo {
 public:
  LabelSwitchTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_size,
                      const std::vector<uint32_t> &label_list, void *cond)
      : TaskInfo(op_name, stream_id, TaskInfoType::LABEL_SWITCH, false),
        label_size_(label_size),
        label_list_(label_list),
        cond_(cond) {}
  ~LabelSwitchTaskInfo() override {}
  uint32_t label_size() const { return label_size_; }
  const std::vector<uint32_t> &label_list() const { return label_list_; }
  void *cond() const { return cond_; }

 private:
  uint32_t label_size_;
  std::vector<uint32_t> label_list_;
  void *cond_;
};

class EventTaskInfo : public TaskInfo {
 public:
  uint32_t event_id() const { return event_id_; }

 protected:
  EventTaskInfo(const std::string &op_name, uint32_t stream_id, TaskInfoType type, uint32_t event_id)
      : TaskInfo(op_name, stream_id, type, false), event_id_(event_id) {}
  ~EventTaskInfo() override {}

  uint32_t event_id_;
};

class EventRecordTaskInfo : public EventTaskInfo {
 public:
  EventRecordTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t event_id)
      : EventTaskInfo(op_name, stream_id, TaskInfoType::EVENT_RECORD, event_id) {}
  ~EventRecordTaskInfo() override {}
};

class EventWaitTaskInfo : public EventTaskInfo {
 public:
  EventWaitTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t event_id)
      : EventTaskInfo(op_name, stream_id, TaskInfoType::EVENT_WAIT, event_id) {}
  ~EventWaitTaskInfo() override {}
};

class FusionStartTaskInfo : public TaskInfo {
 public:
  explicit FusionStartTaskInfo(const std::string &op_name, uint32_t stream_id)
      : TaskInfo(op_name, stream_id, TaskInfoType::FUSION_START, false) {}
  ~FusionStartTaskInfo() override {}
};

class FusionEndTaskInfo : public TaskInfo {
 public:
  explicit FusionEndTaskInfo(const std::string &op_name, uint32_t stream_id)
      : TaskInfo(op_name, stream_id, TaskInfoType::FUSION_END, false) {}
  ~FusionEndTaskInfo() override {}
};

class HcclTaskInfo : public TaskInfo {
 public:
  HcclTaskInfo(const std::string &op_name, uint32_t stream_id, const std::string &hccl_type, void *input_data_addr,
               void *output_data_addr, void *workspace_addr, int64_t workspace_size, int64_t hccl_stream_num,
               const std::vector<uint8_t> &private_def, void *ops_kernel_store, int32_t count, int64_t root_id,
               int64_t op_type, int64_t data_type, const std::string &group, bool dump_flag)
      : TaskInfo(op_name, stream_id, TaskInfoType::HCCL, dump_flag),
        hccl_type_(hccl_type),
        input_data_addr_(input_data_addr),
        output_data_addr_(output_data_addr),
        workspace_addr_(workspace_addr),
        workspace_size_(workspace_size),
        hccl_stream_num_(hccl_stream_num),
        private_def_(private_def),
        ops_kernel_store_(ops_kernel_store),
        count_(count),
        root_id_(root_id),
        op_type_(op_type),
        data_type_(data_type),
        group_(group) {}
  ~HcclTaskInfo() override {}

  const std::string &hccl_type() const { return hccl_type_; }
  void *input_data_addr() const { return input_data_addr_; }
  void *output_data_addr() const { return output_data_addr_; }
  void *workspace_addr() const { return workspace_addr_; }
  int64_t workspace_size() const { return workspace_size_; }
  int64_t hccl_stream_num() const { return hccl_stream_num_; }
  const std::vector<uint8_t> &private_def() const { return private_def_; }
  void *ops_kernel_store() const { return ops_kernel_store_; }
  int32_t count() const { return count_; }
  int64_t root_id() const { return root_id_; }
  int64_t op_type() const { return op_type_; }
  int64_t data_type() const { return data_type_; }
  const std::string &group() const { return group_; }

 private:
  std::string hccl_type_;
  void *input_data_addr_;
  void *output_data_addr_;
  void *workspace_addr_;
  int64_t workspace_size_;
  int64_t hccl_stream_num_;
  std::vector<uint8_t> private_def_;
  void *ops_kernel_store_;
  int32_t count_;
  int64_t root_id_;
  int64_t op_type_;
  int64_t data_type_;
  std::string group_;
};

class ProfilerTraceTaskInfo : public TaskInfo {
 public:
  ProfilerTraceTaskInfo(const std::string &op_name, uint32_t stream_id, uint64_t log_id, bool notify, uint32_t flat)
      : TaskInfo(op_name, stream_id, TaskInfoType::PROFILER_TRACE, false),
        log_id_(log_id),
        notify_(notify),
        flat_(flat) {}
  ~ProfilerTraceTaskInfo() override {}

  uint64_t log_id() const { return log_id_; }
  bool notify() const { return notify_; }
  uint32_t flat() const { return flat_; }

 private:
  uint64_t log_id_;
  bool notify_;
  uint32_t flat_;
};

class MemcpyAsyncTaskInfo : public TaskInfo {
 public:
  MemcpyAsyncTaskInfo(const std::string &op_name, uint32_t stream_id, void *dst, uint64_t dst_max, void *src,
                      uint64_t count, uint32_t kind, bool dump_flag)
      : TaskInfo(op_name, stream_id, TaskInfoType::MEMCPY_ASYNC, dump_flag),
        dst_(dst),
        dst_max_(dst_max),
        src_(src),
        count_(count),
        kind_(kind) {}
  ~MemcpyAsyncTaskInfo() override {}

  void *dst() const { return dst_; }
  uint64_t dst_max() const { return dst_max_; }
  void *src() const { return src_; }
  uint64_t count() const { return count_; }
  uint32_t kind() const { return kind_; }

 private:
  void *dst_;
  uint64_t dst_max_;
  void *src_;
  uint64_t count_;
  uint32_t kind_;
};

class StreamSwitchTaskInfo : public TaskInfo {
 public:
  StreamSwitchTaskInfo(const std::string &op_name, uint32_t stream_id, int64_t true_stream_id, void *input_addr,
                       void *value_addr, int64_t cond, int64_t data_type)
      : TaskInfo(op_name, stream_id, TaskInfoType::STREAM_SWITCH, false),
        true_stream_id_(true_stream_id),
        input_addr_(input_addr),
        value_addr_(value_addr),
        cond_(cond),
        data_type_(data_type) {}
  ~StreamSwitchTaskInfo() override {}

  int64_t true_stream_id() const { return true_stream_id_; }
  void *input_addr() const { return input_addr_; }
  void *value_addr() const { return value_addr_; }
  int64_t cond() const { return cond_; }
  int64_t data_type() const { return data_type_; }

 private:
  int64_t true_stream_id_;
  void *input_addr_;
  void *value_addr_;
  int64_t cond_;
  int64_t data_type_;
};

class StreamActiveTaskInfo : public TaskInfo {
 public:
  StreamActiveTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t active_stream_id)
      : TaskInfo(op_name, stream_id, TaskInfoType::STREAM_ACTIVE, false), active_stream_id_(active_stream_id) {}
  ~StreamActiveTaskInfo() override {}

  uint32_t active_stream_id() const { return active_stream_id_; }

 private:
  uint32_t active_stream_id_;
};
}  // namespace mindspore::ge::model_runner
#endif  // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_INFO_H_
@ -25,7 +25,7 @@
#include "runtime/device/kernel_runtime.h"
#include "ir/anf.h"
#include "backend/kernel_compiler/ascend_kernel_mod.h"
#include "framework/ge_runtime/task_info.h"
#include "runtime/device/ascend/ge_runtime/task_info.h"

namespace mindspore {
namespace device {
@ -134,6 +134,7 @@ enum SubModuleId : int {
  SM_HCCL_ADPT,          // Hccl Adapter
  SM_MINDQUANTUM,        // MindQuantum
  SM_RUNTIME_FRAMEWORK,  // Runtime framework
  SM_GE,                 // GraphEngine
  NUM_SUBMODUES          // number of submodules
};
@ -142,34 +143,35 @@ enum SubModuleId : int {
#endif

static const char *SUB_MODULE_NAMES[NUM_SUBMODUES] = {
  "UNKNOWN",            // SM_UNKNOWN
  "CORE",               // SM_CORE
  "ANALYZER",           // SM_ANALYZER
  "COMMON",             // SM_COMMON
  "DEBUG",              // SM_DEBUG
  "OFFLINE_DEBUG",      // SM_OFFLINE_DEBUG
  "DEVICE",             // SM_DEVICE
  "GE_ADPT",            // SM_GE_ADPT
  "IR",                 // SM_IR
  "KERNEL",             // SM_KERNEL
  "MD",                 // SM_MD
  "ME",                 // SM_ME
  "EXPRESS",            // SM_EXPRESS
  "OPTIMIZER",          // SM_OPTIMIZER
  "PARALLEL",           // SM_PARALLEL
  "PARSER",             // SM_PARSER
  "PIPELINE",           // SM_PIPELINE
  "PRE_ACT",            // SM_PRE_ACT
  "PYNATIVE",           // SM_PYNATIVE
  "SESSION",            // SM_SESSION
  "UTILS",              // SM_UTILS
  "VM",                 // SM_VM
  "PROFILER",           // SM_PROFILER
  "PS",                 // SM_PS
  "LITE",               // SM_LITE
  "HCCL_ADPT",          // SM_HCCL_ADPT
  "MINDQUANTUM",        // SM_MINDQUANTUM
  "RUNTIME_FRAMEWORK"   // SM_RUNTIME_FRAMEWORK
  "UNKNOWN",            // SM_UNKNOWN
  "CORE",               // SM_CORE
  "ANALYZER",           // SM_ANALYZER
  "COMMON",             // SM_COMMON
  "DEBUG",              // SM_DEBUG
  "OFFLINE_DEBUG",      // SM_OFFLINE_DEBUG
  "DEVICE",             // SM_DEVICE
  "GE_ADPT",            // SM_GE_ADPT
  "IR",                 // SM_IR
  "KERNEL",             // SM_KERNEL
  "MD",                 // SM_MD
  "ME",                 // SM_ME
  "EXPRESS",            // SM_EXPRESS
  "OPTIMIZER",          // SM_OPTIMIZER
  "PARALLEL",           // SM_PARALLEL
  "PARSER",             // SM_PARSER
  "PIPELINE",           // SM_PIPELINE
  "PRE_ACT",            // SM_PRE_ACT
  "PYNATIVE",           // SM_PYNATIVE
  "SESSION",            // SM_SESSION
  "UTILS",              // SM_UTILS
  "VM",                 // SM_VM
  "PROFILER",           // SM_PROFILER
  "PS",                 // SM_PS
  "LITE",               // SM_LITE
  "HCCL_ADPT",          // SM_HCCL_ADPT
  "MINDQUANTUM",        // SM_MINDQUANTUM
  "RUNTIME_FRAMEWORK",  // SM_RUNTIME_FRAMEWORK
  "GE",                 // SM_GE
};

#if defined(_WIN32) || defined(_WIN64)
@ -23,6 +23,7 @@ include_directories(${CMAKE_SOURCE_DIR}/mindspore/core)
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/stub/runtime/)
include_directories(${CMAKE_BINARY_DIR})
include_directories(${CMAKE_BINARY_DIR}/proto/ge)
include_directories(${CUDA_INCLUDE_DIRS})
MESSAGE("check ut_test ${CMAKE_BINARY_DIR}")
@ -104,6 +105,7 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
    "../../../mindspore/ccsrc/runtime/device/bucket.cc"
    "../../../mindspore/ccsrc/runtime/device/launch_kernel.cc"
    "../../../mindspore/ccsrc/runtime/device/ascend/profiling/*.cc"
    "../../../mindspore/ccsrc/runtime/device/ascend/ge_runtime/*.cc"
    "../../../mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc"
    "../../../mindspore/ccsrc/runtime/device/ascend/ascend_launch_kernel.cc"
    "../../../mindspore/ccsrc/runtime/device/ascend/ascend_launch_mul.cc"
@ -0,0 +1,473 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <memory>
#include <mutex>
#include "common/common_test.h"
#define private public
#include "runtime/device/ascend/ge_runtime/model_runner.h"
#include "runtime/device/ascend/ge_runtime/runtime_model.h"
#include "runtime/device/ascend/ge_runtime/task/task_factory.h"
#include "runtime/device/ascend/ge_runtime/task/aicpu_task.h"
#include "runtime/device/ascend/ge_runtime/task/event_record_task.h"
#include "runtime/device/ascend/ge_runtime/task/event_wait_task.h"
#include "runtime/device/ascend/ge_runtime/task/hccl_task.h"
#include "runtime/device/ascend/ge_runtime/task/label_goto_task.h"
#include "runtime/device/ascend/ge_runtime/task/label_manager.h"
#include "runtime/device/ascend/ge_runtime/task/label_set_task.h"
#include "runtime/device/ascend/ge_runtime/task/label_switch_task.h"
#include "runtime/device/ascend/ge_runtime/task/memcpy_async_task.h"
#include "runtime/device/ascend/ge_runtime/task/profiler_task.h"
#include "runtime/device/ascend/ge_runtime/task/stream_active_task.h"
#include "runtime/device/ascend/ge_runtime/task/stream_switch_task.h"
#include "runtime/device/ascend/ge_runtime/task/tbe_task.h"
#undef private
#include "common/opskernel/ops_kernel_info_store.h"

using namespace mindspore::ge::model_runner;
using namespace testing;

class MockOpsKernelInfoStore : public ge::OpsKernelInfoStore {
 public:
  ge::Status Initialize(const std::map<std::string, std::string> &) override { return ge::SUCCESS; }
  ge::Status Finalize() override { return ge::SUCCESS; }
  void GetAllOpsKernelInfo(std::map<std::string, ge::OpInfo> &infos) const override {}
  bool CheckSupported(const ge::OpDescPtr &opDescPtr, std::string &un_supported_reason) const override { return true; }
  ge::Status LoadTask(ge::GETaskInfo &task) override { return ge::SUCCESS; }
};

namespace mindspore {
class TestAscendGeRuntime : public UT::Common {
 public:
  TestAscendGeRuntime() {}

 private:
  void TearDown() override {
    {
      std::lock_guard<std::mutex> lock(HcclTask::model_stream_mapping_mutex_);
      HcclTask::model_stream_mapping_.clear();
    }
  }
};
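
// The rt* handles below are fake integer pointers: the ut build links the stub runtime implementations (see the
// stub/runtime include directory added in the ut CMakeLists), so no real device is touched.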
TEST_F(TestAscendGeRuntime, test_task_create_null_task_info_failed) {
  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
                             {reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(1)},
                             {reinterpret_cast<rtLabel_t>(1)}, {reinterpret_cast<rtEvent_t>(1)});
  ASSERT_TRUE(TaskFactory::GetInstance().Create(model_context, nullptr) == nullptr);
}

TEST_F(TestAscendGeRuntime, test_aicpu_task_create_one_stream_success) {
  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
                             {reinterpret_cast<rtStream_t>(1)}, {reinterpret_cast<rtLabel_t>(1)},
                             {reinterpret_cast<rtEvent_t>(1)});
  std::shared_ptr<TaskInfo> aicpu_task_info = std::make_shared<AicpuTaskInfo>(
    "op_name", 0, "so_name", "kernel_name", "node_def", "ext_info", std::vector<void *>{reinterpret_cast<void *>(1)},
    std::vector<void *>{reinterpret_cast<void *>(1)}, true);
  std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, aicpu_task_info);
  ASSERT_TRUE(std::dynamic_pointer_cast<AicpuTask>(task) != nullptr);
  ASSERT_NO_THROW(task->Distribute());
}

TEST_F(TestAscendGeRuntime, test_aicpu_task_create_multi_stream_success) {
  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
                             {reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(1)},
                             {reinterpret_cast<rtLabel_t>(1)}, {reinterpret_cast<rtEvent_t>(1)});
  std::shared_ptr<TaskInfo> aicpu_task_info = std::make_shared<AicpuTaskInfo>(
    "op_name", 0, "so_name", "kernel_name", "node_def", "", std::vector<void *>{reinterpret_cast<void *>(1)},
    std::vector<void *>{reinterpret_cast<void *>(1)}, true);
  std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, aicpu_task_info);
  ASSERT_TRUE(std::dynamic_pointer_cast<AicpuTask>(task) != nullptr);
  ASSERT_NO_THROW(task->Distribute());
}

TEST_F(TestAscendGeRuntime, test_aicpu_task_create_invalid_stream_id_failed) {
  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
                             {reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(1)},
                             {reinterpret_cast<rtLabel_t>(1)}, {reinterpret_cast<rtEvent_t>(1)});
  std::shared_ptr<TaskInfo> aicpu_task_info = std::make_shared<AicpuTaskInfo>(
    "op_name", 5, "so_name", "kernel_name", "node_def", "", std::vector<void *>{reinterpret_cast<void *>(1)},
    std::vector<void *>{reinterpret_cast<void *>(1)}, true);
  ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(model_context, aicpu_task_info));
}

TEST_F(TestAscendGeRuntime, test_event_record_task_create_success) {
  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
                             {reinterpret_cast<rtStream_t>(1)}, {reinterpret_cast<rtLabel_t>(1)},
                             {reinterpret_cast<rtEvent_t>(1)});
  std::shared_ptr<TaskInfo> event_record_task_info = std::make_shared<EventRecordTaskInfo>("op_name", 0, 0);
  std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, event_record_task_info);
  ASSERT_TRUE(std::dynamic_pointer_cast<EventRecordTask>(task) != nullptr);
  ASSERT_NO_THROW(task->Distribute());
}

TEST_F(TestAscendGeRuntime, test_event_record_task_create_invalid_event_id_failed) {
  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
                             {reinterpret_cast<rtStream_t>(1)}, {reinterpret_cast<rtLabel_t>(1)},
                             {reinterpret_cast<rtEvent_t>(1)});
  std::shared_ptr<TaskInfo> event_record_task_info = std::make_shared<EventRecordTaskInfo>("op_name", 0, 10);
  ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(model_context, event_record_task_info));
}

TEST_F(TestAscendGeRuntime, test_event_wait_task_create_success) {
  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
                             {reinterpret_cast<rtStream_t>(1)}, {reinterpret_cast<rtLabel_t>(1)},
                             {reinterpret_cast<rtEvent_t>(1)});
  std::shared_ptr<TaskInfo> event_record_task_info = std::make_shared<EventWaitTaskInfo>("op_name", 0, 0);
  std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, event_record_task_info);
  ASSERT_TRUE(std::dynamic_pointer_cast<EventWaitTask>(task) != nullptr);
  ASSERT_NO_THROW(task->Distribute());
}

TEST_F(TestAscendGeRuntime, test_event_wait_task_create_invalid_event_id_failed) {
  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
                             {reinterpret_cast<rtStream_t>(1)}, {reinterpret_cast<rtLabel_t>(1)},
                             {reinterpret_cast<rtEvent_t>(1)});
  std::shared_ptr<TaskInfo> event_record_task_info = std::make_shared<EventWaitTaskInfo>("op_name", 0, 10);
  ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(model_context, event_record_task_info));
}

TEST_F(TestAscendGeRuntime, test_hccl_task_create_success) {
  MockOpsKernelInfoStore ops_kernel_info_store;
  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
                             {reinterpret_cast<rtStream_t>(1)}, {reinterpret_cast<rtLabel_t>(1)},
                             {reinterpret_cast<rtEvent_t>(1)});
  std::shared_ptr<TaskInfo> hccl_task_info = std::make_shared<HcclTaskInfo>(
    "op_name", 0, "hccl_type", reinterpret_cast<void *>(1), reinterpret_cast<void *>(2), reinterpret_cast<void *>(3), 4,
    5, std::vector<uint8_t>(6, 7), reinterpret_cast<void *>(&ops_kernel_info_store), 9, 10, 11, 12, "group", true);
  std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, hccl_task_info);
  ASSERT_TRUE(std::dynamic_pointer_cast<HcclTask>(task) != nullptr);
  ASSERT_NO_THROW(task->Distribute());
}

TEST_F(TestAscendGeRuntime, test_hccl_task_create_stream_reuse_success) {
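  // Three HCCL tasks created for the same (model, stream_id) share the cached secondary streams; the stored vector
  // grows to the largest stream_num requested, and every weak_ptr in it must still be alive.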
  const rtModel_t model = reinterpret_cast<rtModel_t>(0x12345678);
  const rtStream_t stream = reinterpret_cast<rtStream_t>(0x87654321);
  constexpr uint32_t stream_id = 0;
  constexpr int64_t task1_stream_num = 3;
  constexpr int64_t task2_stream_num = 5;
  constexpr int64_t task3_stream_num = 4;
  MockOpsKernelInfoStore ops_kernel_info_store;
  ModelContext model_context(0, 0, 0, model, reinterpret_cast<rtStream_t>(2), {stream},
                             {reinterpret_cast<rtLabel_t>(1)}, {reinterpret_cast<rtEvent_t>(1)});
  std::shared_ptr<TaskInfo> hccl_task_info_1 = std::make_shared<HcclTaskInfo>(
    "op_name", stream_id, "hccl_type", reinterpret_cast<void *>(1), reinterpret_cast<void *>(2),
    reinterpret_cast<void *>(3), 4, task1_stream_num, std::vector<uint8_t>(6, 7),
    reinterpret_cast<void *>(&ops_kernel_info_store), 9, 10, 11, 12, "group", true);
  std::shared_ptr<TaskInfo> hccl_task_info_2 = std::make_shared<HcclTaskInfo>(
    "op_name", stream_id, "hccl_type", reinterpret_cast<void *>(1), reinterpret_cast<void *>(2),
    reinterpret_cast<void *>(3), 4, task2_stream_num, std::vector<uint8_t>(6, 7),
    reinterpret_cast<void *>(&ops_kernel_info_store), 9, 10, 11, 12, "group", true);
  std::shared_ptr<TaskInfo> hccl_task_info_3 = std::make_shared<HcclTaskInfo>(
    "op_name", stream_id, "hccl_type", reinterpret_cast<void *>(1), reinterpret_cast<void *>(2),
    reinterpret_cast<void *>(3), 4, task3_stream_num, std::vector<uint8_t>(6, 7),
    reinterpret_cast<void *>(&ops_kernel_info_store), 9, 10, 11, 12, "group", true);
  std::shared_ptr<Task> task_1 = TaskFactory::GetInstance().Create(model_context, hccl_task_info_1);
  std::shared_ptr<Task> task_2 = TaskFactory::GetInstance().Create(model_context, hccl_task_info_2);
  std::shared_ptr<Task> task_3 = TaskFactory::GetInstance().Create(model_context, hccl_task_info_3);
  ASSERT_TRUE(std::dynamic_pointer_cast<HcclTask>(task_1) != nullptr);
  ASSERT_TRUE(std::dynamic_pointer_cast<HcclTask>(task_2) != nullptr);
  ASSERT_TRUE(std::dynamic_pointer_cast<HcclTask>(task_3) != nullptr);
  ASSERT_NO_THROW(task_1->Distribute());
  ASSERT_NO_THROW(task_2->Distribute());
  ASSERT_NO_THROW(task_3->Distribute());
  {
    std::lock_guard<std::mutex> lock(HcclTask::model_stream_mapping_mutex_);
    auto model_iter = HcclTask::model_stream_mapping_.find(model);
    ASSERT_NE(model_iter, HcclTask::model_stream_mapping_.end());
    auto stream_iter = model_iter->second.find(stream_id);
    ASSERT_NE(stream_iter, model_iter->second.end());
    const auto &stream_vec = stream_iter->second;
    ASSERT_EQ(stream_vec.size(), std::max(task1_stream_num, std::max(task2_stream_num, task3_stream_num)));
    for (const auto &s : stream_vec) {
      auto shared = s.lock();
      ASSERT_TRUE(shared != nullptr);
    }
  }
}

TEST_F(TestAscendGeRuntime, test_label_goto_task_create_success) {
  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
                             {reinterpret_cast<rtStream_t>(1)}, {reinterpret_cast<rtLabel_t>(1)},
                             {reinterpret_cast<rtEvent_t>(1)});
  std::shared_ptr<TaskInfo> label_goto_task_info = std::make_shared<LabelGotoTaskInfo>("op_name", 0, 0);
  std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, label_goto_task_info);
  auto label_goto_task = std::dynamic_pointer_cast<LabelGotoTask>(task);
  ASSERT_TRUE(label_goto_task != nullptr);
  ASSERT_NO_THROW(task->Distribute());
  label_goto_task->index_value_ = new uint8_t[5];
}

TEST_F(TestAscendGeRuntime, test_label_goto_task_create_invalid_label_id_failed) {
  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
                             {reinterpret_cast<rtStream_t>(1)}, {reinterpret_cast<rtLabel_t>(1)},
                             {reinterpret_cast<rtEvent_t>(1)});
  std::shared_ptr<TaskInfo> label_goto_task_info = std::make_shared<LabelGotoTaskInfo>("op_name", 0, 1);
  ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(model_context, label_goto_task_info));
}

TEST_F(TestAscendGeRuntime, test_label_goto_task_reuse_success) {
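  // Two goto tasks created from the same label id should end up sharing the same cached label_info_.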
  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
                             {reinterpret_cast<rtStream_t>(1)}, {reinterpret_cast<rtLabel_t>(1)},
                             {reinterpret_cast<rtEvent_t>(1)});
  std::shared_ptr<TaskInfo> label_goto_task_info = std::make_shared<LabelGotoTaskInfo>("op_name", 0, 0);
  std::shared_ptr<Task> task1 = TaskFactory::GetInstance().Create(model_context, label_goto_task_info);
  std::shared_ptr<Task> task2 = TaskFactory::GetInstance().Create(model_context, label_goto_task_info);
  auto label_goto_task_1 = std::dynamic_pointer_cast<LabelGotoTask>(task1);
  auto label_goto_task_2 = std::dynamic_pointer_cast<LabelGotoTask>(task2);
  ASSERT_TRUE(label_goto_task_1 != nullptr);
  ASSERT_NO_THROW(task1->Distribute());
  ASSERT_TRUE(label_goto_task_2 != nullptr);
  ASSERT_NO_THROW(task2->Distribute());
  ASSERT_EQ(label_goto_task_1->label_info_, label_goto_task_2->label_info_);
}

TEST_F(TestAscendGeRuntime, test_label_set_task_create_success) {
  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
                             {reinterpret_cast<rtStream_t>(1)}, {reinterpret_cast<rtLabel_t>(1)},
                             {reinterpret_cast<rtEvent_t>(1)});
  std::shared_ptr<TaskInfo> label_set_task_info = std::make_shared<LabelSetTaskInfo>("op_name", 0, 0);
  std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, label_set_task_info);
  ASSERT_TRUE(std::dynamic_pointer_cast<LabelSetTask>(task) != nullptr);
  ASSERT_NO_THROW(task->Distribute());
}

TEST_F(TestAscendGeRuntime, test_label_set_task_create_invalid_label_id_failed) {
  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
                             {reinterpret_cast<rtStream_t>(1)}, {reinterpret_cast<rtLabel_t>(1)},
                             {reinterpret_cast<rtEvent_t>(1)});
  std::shared_ptr<TaskInfo> label_set_task_info = std::make_shared<LabelSetTaskInfo>("op_name", 0, 1);
  ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(model_context, label_set_task_info));
}
|
||||
|
||||
TEST_F(TestAscendGeRuntime, test_label_switch_task_create_success) {
|
||||
ModelContext model_context(
|
||||
0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), {reinterpret_cast<rtStream_t>(1)},
|
||||
{reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)}, {reinterpret_cast<rtEvent_t>(1)});
|
||||
std::shared_ptr<TaskInfo> label_switch_task_info =
|
||||
std::make_shared<LabelSwitchTaskInfo>("op_name", 0, 2, std::vector<uint32_t>{0, 1}, reinterpret_cast<void *>(1));
|
||||
std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, label_switch_task_info);
|
||||
ASSERT_TRUE(std::dynamic_pointer_cast<LabelSwitchTask>(task) != nullptr);
|
||||
ASSERT_NO_THROW(task->Distribute());
|
||||
}
|
||||
|
||||
TEST_F(TestAscendGeRuntime, test_label_switch_task_create_invalid_stream_id_failed) {
|
||||
ModelContext model_context(
|
||||
0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), {reinterpret_cast<rtStream_t>(1)},
|
||||
{reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)}, {reinterpret_cast<rtEvent_t>(1)});
|
||||
std::shared_ptr<TaskInfo> label_switch_task_info =
|
||||
std::make_shared<LabelSwitchTaskInfo>("op_name", 1, 2, std::vector<uint32_t>{0, 1}, reinterpret_cast<void *>(1));
|
||||
ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(model_context, label_switch_task_info));
|
||||
}
|
||||
|
||||
TEST_F(TestAscendGeRuntime, test_label_switch_task_create_invalid_label_id_failed) {
|
||||
ModelContext model_context(
|
||||
0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), {reinterpret_cast<rtStream_t>(1)},
|
||||
{reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)}, {reinterpret_cast<rtEvent_t>(1)});
|
||||
std::shared_ptr<TaskInfo> label_switch_task_info =
|
||||
std::make_shared<LabelSwitchTaskInfo>("op_name", 0, 3, std::vector<uint32_t>{0, 1, 2}, reinterpret_cast<void *>(1));
|
||||
ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(model_context, label_switch_task_info));
|
||||
}
|
||||
|
||||
TEST_F(TestAscendGeRuntime, test_label_switch_task_reuse_success) {
|
||||
ModelContext model_context(
|
||||
0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), {reinterpret_cast<rtStream_t>(1)},
|
||||
{reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)}, {reinterpret_cast<rtEvent_t>(1)});
|
||||
std::shared_ptr<TaskInfo> label_switch_task_info =
|
||||
std::make_shared<LabelSwitchTaskInfo>("op_name", 0, 2, std::vector<uint32_t>{0, 1}, reinterpret_cast<void *>(1));
|
||||
std::shared_ptr<Task> task1 = TaskFactory::GetInstance().Create(model_context, label_switch_task_info);
|
||||
std::shared_ptr<Task> task2 = TaskFactory::GetInstance().Create(model_context, label_switch_task_info);
|
||||
auto label_switch_task_1 = std::dynamic_pointer_cast<LabelSwitchTask>(task1);
|
||||
auto label_switch_task_2 = std::dynamic_pointer_cast<LabelSwitchTask>(task2);
|
||||
ASSERT_TRUE(label_switch_task_1 != nullptr);
|
||||
ASSERT_TRUE(label_switch_task_2 != nullptr);
|
||||
ASSERT_NO_THROW(task1->Distribute());
|
||||
ASSERT_NO_THROW(task2->Distribute());
|
||||
ASSERT_EQ(label_switch_task_1->label_info_, label_switch_task_2->label_info_);
|
||||
}
|
||||
|
||||
TEST_F(TestAscendGeRuntime, test_memcpy_async_task_create_success) {
|
||||
ModelContext model_context(
|
||||
0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), {reinterpret_cast<rtStream_t>(1)},
|
||||
{reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)}, {reinterpret_cast<rtEvent_t>(1)});
|
||||
std::shared_ptr<TaskInfo> memcpy_task_info = std::make_shared<MemcpyAsyncTaskInfo>(
|
||||
"op_name", 0, reinterpret_cast<void *>(1), 2, reinterpret_cast<void *>(3), 4, 5, true);
|
||||
std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, memcpy_task_info);
|
||||
ASSERT_TRUE(std::dynamic_pointer_cast<MemcpyAsyncTask>(task) != nullptr);
|
||||
ASSERT_NO_THROW(task->Distribute());
|
||||
}
|
||||
|
||||
TEST_F(TestAscendGeRuntime, test_memcpy_async_task_create_invalid_stream_id_failed) {
|
||||
ModelContext model_context(
|
||||
0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), {reinterpret_cast<rtStream_t>(1)},
|
||||
{reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)}, {reinterpret_cast<rtEvent_t>(1)});
|
||||
std::shared_ptr<TaskInfo> memcpy_task_info = std::make_shared<MemcpyAsyncTaskInfo>(
|
||||
"op_name", 1, reinterpret_cast<void *>(1), 2, reinterpret_cast<void *>(3), 4, 5, true);
|
||||
ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(model_context, memcpy_task_info));
|
||||
}
|
||||
|
||||
TEST_F(TestAscendGeRuntime, test_profiler_task_create_success) {
|
||||
ModelContext model_context(
|
||||
0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), {reinterpret_cast<rtStream_t>(1)},
|
||||
{reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)}, {reinterpret_cast<rtEvent_t>(1)});
|
||||
std::shared_ptr<TaskInfo> profiler_task_info = std::make_shared<ProfilerTraceTaskInfo>("op_name", 0, 1, true, 2);
|
||||
std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, profiler_task_info);
|
||||
ASSERT_TRUE(std::dynamic_pointer_cast<ProfilerTask>(task) != nullptr);
|
||||
ASSERT_NO_THROW(task->Distribute());
|
||||
}
|
||||
|
||||
TEST_F(TestAscendGeRuntime, test_profiler_task_create_invalid_stream_id_failed) {
|
||||
ModelContext model_context(
|
||||
0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), {reinterpret_cast<rtStream_t>(1)},
|
||||
{reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)}, {reinterpret_cast<rtEvent_t>(1)});
|
||||
std::shared_ptr<TaskInfo> profiler_task_info = std::make_shared<ProfilerTraceTaskInfo>("op_name", 1, 1, true, 2);
|
||||
ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(model_context, profiler_task_info));
|
||||
}
|
||||
|
TEST_F(TestAscendGeRuntime, test_stream_active_task_create_success) {
  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
                             {reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(2)},
                             {reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)},
                             {reinterpret_cast<rtEvent_t>(1)});
  std::shared_ptr<TaskInfo> stream_active_task_info = std::make_shared<StreamActiveTaskInfo>("op_name", 0, 1);
  std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, stream_active_task_info);
  ASSERT_TRUE(std::dynamic_pointer_cast<StreamActiveTask>(task) != nullptr);
  ASSERT_NO_THROW(task->Distribute());
}

TEST_F(TestAscendGeRuntime, test_stream_active_task_create_invalid_active_stream_id_failed) {
  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
                             {reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(2)},
                             {reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)},
                             {reinterpret_cast<rtEvent_t>(1)});
  std::shared_ptr<TaskInfo> stream_active_task_info = std::make_shared<StreamActiveTaskInfo>("op_name", 0, 2);
  ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(model_context, stream_active_task_info));
}

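// Here it is the third argument (active_stream_id = 2) that is out of range:
// only streams 0 and 1 are registered, so the activated stream presumably
// fails the same Create-time lookup as the task's own stream_id.
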
TEST_F(TestAscendGeRuntime, test_stream_switch_task_create_success) {
  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
                             {reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(2)},
                             {reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)},
                             {reinterpret_cast<rtEvent_t>(1)});
  std::shared_ptr<TaskInfo> stream_switch_task_info = std::make_shared<StreamSwitchTaskInfo>(
    "op_name", 0, 1, reinterpret_cast<void *>(2), reinterpret_cast<void *>(3), 4, 5);
  std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, stream_switch_task_info);
  ASSERT_TRUE(std::dynamic_pointer_cast<StreamSwitchTask>(task) != nullptr);
  ASSERT_NO_THROW(task->Distribute());
}

TEST_F(TestAscendGeRuntime, test_stream_switch_task_create_invalid_true_stream_id_failed) {
  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
                             {reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(2)},
                             {reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)},
                             {reinterpret_cast<rtEvent_t>(1)});
  std::shared_ptr<TaskInfo> stream_switch_task_info = std::make_shared<StreamSwitchTaskInfo>(
    "op_name", 0, 2, reinterpret_cast<void *>(2), reinterpret_cast<void *>(3), 4, 5);
  std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, stream_switch_task_info);
  ASSERT_TRUE(std::dynamic_pointer_cast<StreamSwitchTask>(task) != nullptr);
  ASSERT_ANY_THROW(task->Distribute());
}

TEST_F(TestAscendGeRuntime, test_stream_switch_task_create_invalid_stream_id_failed) {
  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
                             {reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(2)},
                             {reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)},
                             {reinterpret_cast<rtEvent_t>(1)});
  std::shared_ptr<TaskInfo> stream_switch_task_info = std::make_shared<StreamSwitchTaskInfo>(
    "op_name", 2, 1, reinterpret_cast<void *>(2), reinterpret_cast<void *>(3), 4, 5);
  ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(model_context, stream_switch_task_info));
}

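// The two failure tests above expose an asymmetry worth noting: an invalid
// stream_id (2) is rejected at Create(), while an invalid true_stream_id (2)
// survives Create() and only throws at Distribute(). Presumably the task's own
// stream is resolved in the constructor, whereas the switch target is resolved
// when the underlying runtime stream-switch call is issued.
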
TEST_F(TestAscendGeRuntime, test_tbe_task_create_success) {
  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
                             {reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(2)},
                             {reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)},
                             {reinterpret_cast<rtEvent_t>(1)});
  std::shared_ptr<TaskInfo> tbe_task_info = std::make_shared<TbeTaskInfo>(
    "op_name", 0, "stub_func", 1, std::vector<uint8_t>(100, 2), 100, std::vector<uint8_t>{5, 6},
    reinterpret_cast<void *>(7), 8, std::vector<uint8_t>{9, 10},
    std::vector<void *>{reinterpret_cast<void *>(11), reinterpret_cast<void *>(12)},
    std::vector<void *>{reinterpret_cast<void *>(13), reinterpret_cast<void *>(14)},
    std::vector<void *>{reinterpret_cast<void *>(15), reinterpret_cast<void *>(16)}, true);
  std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, tbe_task_info);
  auto tbe_task = std::dynamic_pointer_cast<TbeTask>(task);
  ASSERT_TRUE(tbe_task != nullptr);
  ASSERT_NO_THROW(task->Distribute());
  tbe_task->args_ = new uint8_t[5];
}

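// The trailing `tbe_task->args_ = new uint8_t[5];` looks deliberate: it leaves
// a heap allocation owned by the task so that the TbeTask destructor's cleanup
// of args_ is exercised when the test ends (an inference from the test; if the
// destructor did not free args_, a leak checker would flag this test).
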
TEST_F(TestAscendGeRuntime, test_tbe_task_create_invalid_stream_id_failed) {
  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
                             {reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(2)},
                             {reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)},
                             {reinterpret_cast<rtEvent_t>(1)});
  std::shared_ptr<TaskInfo> tbe_task_info = std::make_shared<TbeTaskInfo>(
    "op_name", 3, "stub_func", 1, std::vector<uint8_t>(100, 2), 100, std::vector<uint8_t>{5, 6},
    reinterpret_cast<void *>(7), 8, std::vector<uint8_t>{9, 10},
    std::vector<void *>{reinterpret_cast<void *>(11), reinterpret_cast<void *>(12)},
    std::vector<void *>{reinterpret_cast<void *>(13), reinterpret_cast<void *>(14)},
    std::vector<void *>{reinterpret_cast<void *>(15), reinterpret_cast<void *>(16)}, true);
  ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(model_context, tbe_task_info));
}

TEST_F(TestAscendGeRuntime, test_tbe_task_create_empty_stub_func_failed) {
  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
                             {reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(2)},
                             {reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)},
                             {reinterpret_cast<rtEvent_t>(1)});
  std::shared_ptr<TaskInfo> tbe_task_info = std::make_shared<TbeTaskInfo>(
    "op_name", 0, "", 1, std::vector<uint8_t>(100, 2), 100, std::vector<uint8_t>{5, 6}, reinterpret_cast<void *>(7), 8,
    std::vector<uint8_t>{9, 10}, std::vector<void *>{reinterpret_cast<void *>(11), reinterpret_cast<void *>(12)},
    std::vector<void *>{reinterpret_cast<void *>(13), reinterpret_cast<void *>(14)},
    std::vector<void *>{reinterpret_cast<void *>(15), reinterpret_cast<void *>(16)}, true);
  std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, tbe_task_info);
  ASSERT_TRUE(std::dynamic_pointer_cast<TbeTask>(task) != nullptr);
  ASSERT_ANY_THROW(task->Distribute());
}

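// With an empty stub_func, Create() still succeeds; the failure is deferred to
// Distribute(), which presumably guards the kernel-entry lookup along these
// lines (a sketch under that assumption, reusing the rtGetFunctionByName
// signature visible in the stubs below; the accessors are hypothetical):
//
//   if (task_info_->stub_func().empty()) {
//     MS_LOG(EXCEPTION) << "stub_func is empty, op: " << task_info_->op_name();
//   }
//   rtError_t rt_ret = rtGetFunctionByName(task_info_->stub_func().c_str(), &stub_func_);
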
TEST_F(TestAscendGeRuntime, test_model_runner_success) {
  constexpr uint32_t model_id = 0;
  ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
                             {reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(2)},
                             {reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)},
                             {reinterpret_cast<rtEvent_t>(1)});
  std::shared_ptr<TaskInfo> tbe_task_info = std::make_shared<TbeTaskInfo>(
    "op_name", 0, "stub_func", 1, std::vector<uint8_t>(100, 2), 100, std::vector<uint8_t>{5, 6},
    reinterpret_cast<void *>(7), 8, std::vector<uint8_t>{9, 10},
    std::vector<void *>{reinterpret_cast<void *>(11), reinterpret_cast<void *>(12)},
    std::vector<void *>{reinterpret_cast<void *>(13), reinterpret_cast<void *>(14)},
    std::vector<void *>{reinterpret_cast<void *>(15), reinterpret_cast<void *>(16)}, true);
  std::shared_ptr<TaskInfo> aicpu_task_info = std::make_shared<AicpuTaskInfo>(
    "op_name", 0, "so_name", "kernel_name", "node_def", "ext_info", std::vector<void *>{reinterpret_cast<void *>(1)},
    std::vector<void *>{reinterpret_cast<void *>(1)}, true);
  auto davinci_model =
    std::make_shared<DavinciModel>(std::vector<std::shared_ptr<TaskInfo>>{tbe_task_info, aicpu_task_info},
                                   std::vector<uint32_t>{}, std::vector<uint32_t>{}, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0);
  ASSERT_NO_THROW(ModelRunner::Instance().LoadDavinciModel(0, 0, model_id, davinci_model));
  auto iter = ModelRunner::Instance().runtime_models_.find(model_id);
  ASSERT_TRUE(iter != ModelRunner::Instance().runtime_models_.end());
  auto &task_list = iter->second->task_list_;
  task_list.clear();
  ASSERT_NO_THROW(task_list.emplace_back(TaskFactory::GetInstance().Create(model_context, tbe_task_info)));
  ASSERT_NO_THROW(task_list.emplace_back(TaskFactory::GetInstance().Create(model_context, aicpu_task_info)));
  ASSERT_NO_THROW(ModelRunner::Instance().DistributeTask(model_id));
  ASSERT_NO_THROW(ModelRunner::Instance().LoadModelComplete(model_id));
  ASSERT_NO_THROW(ModelRunner::Instance().RunModel(model_id));
  ASSERT_FALSE(ModelRunner::Instance().GetTaskIdList(model_id).empty());
  ASSERT_FALSE(ModelRunner::Instance().GetStreamIdList(model_id).empty());
  ASSERT_FALSE(ModelRunner::Instance().GetRuntimeInfoMap(model_id).empty());
  ASSERT_NO_THROW(ModelRunner::Instance().GetModelHandle(model_id));
  ASSERT_NO_THROW(ModelRunner::Instance().UnloadModel(model_id));
}

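// The test above walks the full ModelRunner lifecycle in the order a caller is
// expected to follow:
//
//   LoadDavinciModel -> DistributeTask -> LoadModelComplete -> RunModel
//     -> GetTaskIdList / GetStreamIdList / GetRuntimeInfoMap -> UnloadModel
//
// Reaching into runtime_models_ and task_list_ to swap in freshly created
// tasks works only because those members are accessible to the test fixture
// (presumably via a friend declaration or test-only visibility).
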
}  // namespace mindspore

@ -14,51 +14,8 @@
 * limitations under the License.
 */
#include <vector>
#include "framework/ge_runtime/model_runner.h"
#include "runtime/hccl_adapter/hccl_adapter.h"

namespace ge {
namespace model_runner {
ModelRunner &ModelRunner::Instance() {
  static ModelRunner runner;
  return runner;
}

bool ModelRunner::LoadDavinciModel(uint32_t device_id, uint64_t session_id, uint32_t model_id,
                                   std::shared_ptr<DavinciModel> ascend_model,
                                   std::shared_ptr<ge::ModelListener> listener) {
  return true;
}

bool ModelRunner::UnloadModel(uint32_t model_id) { return true; }

bool ModelRunner::LoadModelComplete(uint32_t model_id) { return true; }

bool ModelRunner::RunModel(uint32_t model_id, const ge::InputData &input_data, ge::OutputData *output_data) {
  return true;
}

void *ModelRunner::GetModelHandle(uint32_t model_id) const { return nullptr; }

bool ModelRunner::DistributeTask(uint32_t model_id) { return true; }

const std::vector<uint32_t> &ModelRunner::GetTaskIdList(uint32_t model_id) const {
  static std::vector<uint32_t> task_id_list;
  return task_id_list;
}

const std::vector<uint32_t> &ModelRunner::GetStreamIdList(uint32_t model_id) const {
  static std::vector<uint32_t> stream_id_list;
  return stream_id_list;
}

const std::map<std::string, std::shared_ptr<RuntimeInfo>> &ModelRunner::GetRuntimeInfoMap(uint32_t model_id) const {
  static std::map<std::string, std::shared_ptr<RuntimeInfo>> runtime_info_map;
  return runtime_info_map;
}
}  // namespace model_runner
}  // namespace ge

namespace mindspore {
namespace hccl {
bool InitHccl(uint32_t, std::string_view, std::string_view) { return true; }

@ -141,9 +141,9 @@ rtGetFunctionByName(const char *stubName, void **stubFunc) { return RT
rtError_t rtSetTaskGenCallback(rtTaskGenCallback callback) { return RT_ERROR_NONE; }

RTS_API rtError_t rtProfilerStart(uint64_t profConfig, int32_t numsDev, uint32_t* deviceList) { return RT_ERROR_NONE; }
RTS_API rtError_t rtProfilerStart(uint64_t profConfig, int32_t numsDev, uint32_t *deviceList) { return RT_ERROR_NONE; }

RTS_API rtError_t rtProfilerStop(uint64_t profConfig, int32_t numsDev, uint32_t* deviceList) { return RT_ERROR_NONE; }
RTS_API rtError_t rtProfilerStop(uint64_t profConfig, int32_t numsDev, uint32_t *deviceList) { return RT_ERROR_NONE; }

int AdxDataDumpServerInit() { return 0; }

@ -151,11 +151,13 @@ int AdxDataDumpServerUnInit() { return 0; }
RTS_API rtError_t rtGetTaskIdAndStreamID(uint32_t *taskid, uint32_t *streamid) { return RT_ERROR_NONE; }

RTS_API rtError_t rtSetTaskFailCallback(rtTaskFailCallback callback) {return RT_ERROR_NONE; }
RTS_API rtError_t rtSetTaskFailCallback(rtTaskFailCallback callback) { return RT_ERROR_NONE; }

RTS_API rtError_t rtRegDeviceStateCallback(const char *regName, rtDeviceStateCallback callback) {return RT_ERROR_NONE; }
RTS_API rtError_t rtRegDeviceStateCallback(const char *regName, rtDeviceStateCallback callback) {
  return RT_ERROR_NONE;
}

RTS_API rtError_t rtSetMsprofReporterCallback(MsprofReporterCallback callback) {return RT_ERROR_NONE; }
RTS_API rtError_t rtSetMsprofReporterCallback(MsprofReporterCallback callback) { return RT_ERROR_NONE; }

RTS_API rtError_t rtRegTaskFailCallbackByModule(const char *moduleName, rtTaskFailCallback callback) {
  return RT_ERROR_NONE;

@ -168,3 +170,28 @@ RTS_API rtError_t rtDevBinaryUnRegister(void *handle) { return RT_ERROR_NONE; }
RTS_API rtError_t rtMemsetAsync(void *ptr, uint64_t destMax, uint32_t value, uint64_t count, rtStream_t stream) {
  return RT_ERROR_NONE;
}

RTS_API rtError_t rtLabelListCpy(rtLabel_t *label, uint32_t labelNumber, void *dst, uint32_t dstMax) {
  return RT_ERROR_NONE;
}

RTS_API rtError_t rtModelGetTaskId(rtModel_t model, uint32_t *taskid, uint32_t *streamid) { return RT_ERROR_NONE; }

RTS_API rtError_t rtLabelCreateEx(rtLabel_t *label, rtStream_t stream) { return RT_ERROR_NONE; }

RTS_API rtError_t rtCpuKernelLaunchWithFlag(const void *soName, const void *kernelName, uint32_t blockDim,
                                            const void *args, uint32_t argsSize, rtSmDesc_t *smDesc, rtStream_t stream,
                                            uint32_t flags) {
  return RT_ERROR_NONE;
}

RTS_API rtError_t rtLabelSwitchByIndex(void *ptr, uint32_t max, void *labelInfoPtr, rtStream_t stream) {
  return RT_ERROR_NONE;
}

RTS_API rtError_t rtProfilerTrace(uint64_t id, bool notify, uint32_t flags, rtStream_t stream) { return RT_ERROR_NONE; }

RTS_API rtError_t rtKernelLaunchWithFlag(const void *stubFunc, uint32_t blockDim, void *args, uint32_t argsSize,
                                         rtSmDesc_t *smDesc, rtStream_t stream, uint32_t flags) {
  return RT_ERROR_NONE;
}

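// These no-op stubs exist only so the unit tests link without a real Ascend
// runtime; each entry point the internal ge_runtime code touches needs a
// stand-in that reports success. If the linker reports another undefined
// runtime symbol, it can presumably be stubbed the same way, e.g.
// (hypothetical, not part of this change):
//
//   RTS_API rtError_t rtStreamSynchronize(rtStream_t stream) { return RT_ERROR_NONE; }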