!15107 use internal ge runtime

From: @zhoufeng54
commit 966f89198a by mindspore-ci-bot, 2021-04-29 10:05:54 +08:00, committed by Gitee
65 changed files with 3495 additions and 139 deletions


@@ -252,7 +252,6 @@ if(NOT ENABLE_GE)
FILES
${CMAKE_BINARY_DIR}/graphengine/metadef/graph/libgraph.so
${CMAKE_BINARY_DIR}/graphengine/ge/common/libge_common.so
${CMAKE_BINARY_DIR}/graphengine/ge/ge_runtime/libge_runtime.so
DESTINATION ${INSTALL_LIB_DIR}
COMPONENT mindspore
)


@@ -313,7 +313,7 @@ if(ENABLE_D)
target_link_options(ms_profile PRIVATE -Wl,-init,common_log_init)
target_link_libraries(ms_profile -Wl,--start-group -Wl,--whole-archive ${PROFILING} -Wl,--no-whole-archive
mindspore::protobuf -Wl,--end-group)
target_link_libraries(mindspore ge_runtime ${CCE_LIB} ${RUNTIME_LIB} ${TSDCLIENT} ${HCCL} ${DATATRANSFER}
target_link_libraries(mindspore ${CCE_LIB} ${RUNTIME_LIB} ${TSDCLIENT} ${HCCL} ${DATATRANSFER}
${HCCL_ADPTER} ${REGISTER} -Wl,--no-as-needed ${OPTILING} ${HCCL_BUILDER}
${HCCL_RA} ${PLATFORM} ${ACL})
target_link_libraries(mindspore -Wl,--start-group proto_input mindspore::protobuf -Wl,--end-group)


@@ -30,7 +30,7 @@
#include "runtime/device/kernel_runtime.h"
#include "runtime/device/ascend/executor/host_dynamic_kernel.h"
using AicpuTaskInfoPtr = std::shared_ptr<ge::model_runner::AicpuTaskInfo>;
using AicpuTaskInfoPtr = std::shared_ptr<mindspore::ge::model_runner::AicpuTaskInfo>;
using AicpuDynamicKernel = mindspore::device::ascend::AiCpuDynamicKernel;
using HostDynamicKernel = mindspore::device::ascend::HostDynamicKernel;
@@ -193,9 +193,9 @@ std::vector<TaskInfoPtr> AicpuOpKernelMod::GenTask(const std::vector<AddressPtr>
node_name_ = kPack;
}
AicpuTaskInfoPtr task_info_ptr =
make_shared<ge::model_runner::AicpuTaskInfo>(kernel_name_, stream_id, node_so_, node_name_, node_def_str_,
ext_info_, input_data_addrs, output_data_addrs, NeedDump());
AicpuTaskInfoPtr task_info_ptr = std::make_shared<mindspore::ge::model_runner::AicpuTaskInfo>(
kernel_name_, stream_id, node_so_, node_name_, node_def_str_, ext_info_, input_data_addrs, output_data_addrs,
NeedDump());
MS_LOG(INFO) << "AicpuOpKernelMod GenTask end";
return {task_info_ptr};


@@ -29,7 +29,7 @@ using std::fstream;
using std::map;
using std::mutex;
using std::string;
using TbeTaskInfoPtr = std::shared_ptr<ge::model_runner::TbeTaskInfo>;
using TbeTaskInfoPtr = std::shared_ptr<mindspore::ge::model_runner::TbeTaskInfo>;
using tbe::KernelManager;
constexpr uint32_t DEFAULT_BLOCK_DIM = 1;
/**
@@ -118,7 +118,7 @@ std::vector<TaskInfoPtr> AkgKernelMod::GenTask(const std::vector<AddressPtr> &in
MS_LOG(DEBUG) << "The block_dim is:" << block_dim;
TbeTaskInfoPtr task_info_ptr = make_shared<ge::model_runner::TbeTaskInfo>(
TbeTaskInfoPtr task_info_ptr = std::make_shared<mindspore::ge::model_runner::TbeTaskInfo>(
kernel_name_, stream_id, stub_func, block_dim, args, args_size, sm_desc, binary, binary_size, meta_data,
input_data_addrs, output_data_addrs, workspace_addrs, NeedDump());
return {task_info_ptr};


@@ -19,11 +19,11 @@
#include <vector>
#include <memory>
#include "framework/ge_runtime/task_info.h"
#include "runtime/device/ascend/ge_runtime/task_info.h"
#include "backend/kernel_compiler/kernel.h"
#include "debug/data_dump/dump_json_parser.h"
using TaskInfoPtr = std::shared_ptr<ge::model_runner::TaskInfo>;
using TaskInfoPtr = std::shared_ptr<mindspore::ge::model_runner::TaskInfo>;
namespace mindspore {
namespace kernel {
class AscendKernelMod : public KernelMod {


@@ -24,8 +24,8 @@
#include "runtime/device/ascend/executor/hccl_dynamic_kernel.h"
#include "runtime/hccl_adapter/hccl_adapter.h"
using HcclTaskInfoPtr = std::shared_ptr<ge::model_runner::HcclTaskInfo>;
using ge::model_runner::HcclTaskInfo;
using HcclTaskInfoPtr = std::shared_ptr<mindspore::ge::model_runner::HcclTaskInfo>;
using mindspore::ge::model_runner::HcclTaskInfo;
namespace {
static std::map<std::string, std::string> kMsOpNameToHcomHcclType = {


@@ -18,7 +18,7 @@
#include <memory>
#include "runtime/mem.h"
using ge::model_runner::MemcpyAsyncTaskInfo;
using mindspore::ge::model_runner::MemcpyAsyncTaskInfo;
using MemcpyAsyncTaskInfoPtr = std::shared_ptr<MemcpyAsyncTaskInfo>;
namespace mindspore {


@@ -20,7 +20,7 @@
#include "framework/ge_runtime/task_info.h"
#include "backend/session/anf_runtime_algorithm.h"
using ge::model_runner::LabelGotoTaskInfo;
using mindspore::ge::model_runner::LabelGotoTaskInfo;
using LabelGotoTaskInfoPtr = std::shared_ptr<LabelGotoTaskInfo>;
namespace mindspore {


@@ -20,7 +20,7 @@
#include "framework/ge_runtime/task_info.h"
#include "backend/session/anf_runtime_algorithm.h"
using ge::model_runner::LabelSetTaskInfo;
using mindspore::ge::model_runner::LabelSetTaskInfo;
using LabelSetTaskInfoPtr = std::shared_ptr<LabelSetTaskInfo>;
namespace mindspore {


@@ -21,7 +21,7 @@
#include "framework/ge_runtime/task_info.h"
#include "backend/session/anf_runtime_algorithm.h"
using ge::model_runner::LabelSwitchTaskInfo;
using mindspore::ge::model_runner::LabelSwitchTaskInfo;
using LabelSwitchTaskInfoPtr = std::shared_ptr<LabelSwitchTaskInfo>;
namespace mindspore {


@@ -25,7 +25,7 @@
#include "runtime/device/kernel_runtime.h"
#include "runtime/device/ascend/executor/rts/memcpy_rts_dynamic_kernel.h"
using ge::model_runner::MemcpyAsyncTaskInfo;
using mindspore::ge::model_runner::MemcpyAsyncTaskInfo;
using MemcpyAsyncTaskInfoPtr = std::shared_ptr<MemcpyAsyncTaskInfo>;
using AddressPtrList = std::vector<mindspore::kernel::AddressPtr>;
using mindspore::device::ascend::MemcpyRtsDynamicKernel;


@@ -23,7 +23,7 @@
#include "backend/session/anf_runtime_algorithm.h"
#include "runtime/device/ascend/executor/rts/profiling_rts_dynamic_kernel.h"
using ProfilerTraceTaskInfo = ge::model_runner::ProfilerTraceTaskInfo;
using ProfilerTraceTaskInfo = mindspore::ge::model_runner::ProfilerTraceTaskInfo;
using mindspore::device::ascend::ProfilingRtsDynamicKernel;
using mindspore::device::ascend::ProfilingUtils;


@@ -23,7 +23,7 @@
namespace mindspore {
namespace kernel {
using ge::model_runner::EventWaitTaskInfo;
using mindspore::ge::model_runner::EventWaitTaskInfo;
using EventWaitTaskInfoPtr = std::shared_ptr<EventWaitTaskInfo>;
RecvKernel::RecvKernel() { event_id_ = 0; }


@@ -20,7 +20,7 @@
#include "framework/ge_runtime/task_info.h"
#include "backend/session/anf_runtime_algorithm.h"
using ge::model_runner::EventRecordTaskInfo;
using mindspore::ge::model_runner::EventRecordTaskInfo;
using EventRecordTaskInfoPtr = std::shared_ptr<EventRecordTaskInfo>;
namespace mindspore {


@@ -20,7 +20,7 @@
#include "framework/ge_runtime/task_info.h"
#include "backend/session/anf_runtime_algorithm.h"
using ge::model_runner::StreamActiveTaskInfo;
using mindspore::ge::model_runner::StreamActiveTaskInfo;
using StreamActiveTaskInfoPtr = std::shared_ptr<StreamActiveTaskInfo>;
namespace mindspore {


@@ -21,7 +21,7 @@
#include "framework/ge_runtime/task_info.h"
#include "backend/session/anf_runtime_algorithm.h"
using ge::model_runner::StreamSwitchTaskInfo;
using mindspore::ge::model_runner::StreamSwitchTaskInfo;
using StreamSwitchTaskInfoPtr = std::shared_ptr<StreamSwitchTaskInfo>;
namespace mindspore {


@@ -24,7 +24,7 @@
namespace mindspore {
namespace kernel {
using TbeTaskInfoPtr = std::shared_ptr<ge::model_runner::TbeTaskInfo>;
using TbeTaskInfoPtr = std::shared_ptr<mindspore::ge::model_runner::TbeTaskInfo>;
using tbe::KernelManager;
using AddressPtrList = std::vector<mindspore::kernel::AddressPtr>;
bool TbeKernelMod::Launch(const std::vector<mindspore::kernel::AddressPtr> &inputs,
@@ -102,7 +102,7 @@ std::vector<TaskInfoPtr> TbeKernelMod::GenTask(const std::vector<AddressPtr> &in
MS_LOG(INFO) << "block_dim is:" << block_dim_;
TbeTaskInfoPtr task_info_ptr = make_shared<ge::model_runner::TbeTaskInfo>(
TbeTaskInfoPtr task_info_ptr = std::make_shared<mindspore::ge::model_runner::TbeTaskInfo>(
kernel_name_, stream_id, stub_func, block_dim_, args, 0, sm_desc, nullptr, 0, meta_data, input_data_addrs,
output_data_addrs, workspace_addrs, NeedDump());
return {task_info_ptr};


@@ -36,7 +36,7 @@ using mindspore::kernel::tbe::TbeUtils;
bool TbeOpParallelBuild(const std::vector<AnfNodePtr> &anf_nodes) {
auto build_manger = std::make_shared<ParallelBuildManager>();
MS_EXCEPTION_IF_NULL(build_manger);
static set<std::string> processed_kernel;
static std::set<std::string> processed_kernel;
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
auto tune_mode = context_ptr->get_param<std::string>(MS_CTX_TUNE_MODE);
@@ -259,8 +259,8 @@ bool ParallelBuildManager::SearchInCache(const std::string &json_name, const std
}
KernelModPtr ParallelBuildManager::GenKernelMod(const string &json_name, const string &processor,
const vector<size_t> &input_size_list,
const vector<size_t> &output_size_list,
const std::vector<size_t> &input_size_list,
const std::vector<size_t> &output_size_list,
const mindspore::kernel::KernelPackPtr &kernel_pack) const {
MS_EXCEPTION_IF_NULL(kernel_pack);
auto kernel_json_info = kernel_pack->kernel_json_info();


@@ -27,6 +27,7 @@
#include "proto/tensor_shape.pb.h"
#include "proto/attr.pb.h"
#include "proto/node_def.pb.h"
#include "runtime/rt.h"
using mindspore::kernel::Address;
using AddressPtr = std::shared_ptr<Address>;


@@ -24,6 +24,7 @@
#include "ps/ps_cache/ps_cache_basic.h"
#include "backend/kernel_compiler/aicpu/aicpu_kernel_mod.h"
#include "ir/dtype.h"
#include "runtime/base.h"
namespace mindspore {
namespace ps {


@@ -79,3 +79,9 @@ list(REMOVE_ITEM D_SRC_LIST "ascend/profiling/profiling_callback_register.cc")
set_property(SOURCE ${DEVICE_SRC_LIST} ${D_SRC_LIST} ${CPU_SRC_LIST}
PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE)
add_library(_mindspore_runtime_device_obj OBJECT ${DEVICE_SRC_LIST} ${D_SRC_LIST} ${CPU_SRC_LIST} ${TDT_SRC_LIST})
if(ENABLE_D)
file(GLOB_RECURSE GE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "ascend/ge_runtime/*.cc")
set_property(SOURCE ${GE_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_GE)
target_include_directories(_mindspore_runtime_device_obj PRIVATE ${CMAKE_BINARY_DIR}/proto/ge)
add_dependencies(_mindspore_runtime_device_obj graph)
endif()


@@ -28,9 +28,9 @@
#include "utils/mpi/mpi_config.h"
#include "runtime/device/ascend/profiling/profiling_manager.h"
#include "common/trans.h"
#include "runtime/context.h"
#include "runtime/rt.h"
#include "runtime/device/ascend/ascend_stream_assign.h"
#include "framework/ge_runtime/model_runner.h"
#include "runtime/device/ascend/ge_runtime/model_runner.h"
#include "runtime/device/ascend/tasksink/task_generator.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "runtime/device/ascend/profiling/profiling_utils.h"
@@ -40,7 +40,6 @@
#include "toolchain/adx_datadump_server.h"
#include "utils/trace_base.h"
#include "graphengine/inc/external/acl/error_codes/rt_error_codes.h"
#include "utils/runtime_error_codes.h"
#include "debug/anf_ir_dump.h"
#ifdef MEM_REUSE_DEBUG
#include "backend/optimizer/mem_reuse/mem_reuse_checker.h"
@@ -61,10 +60,10 @@ using mindspore::dataset::TdtHandle;
#include "debug/rdr/running_data_recorder.h"
#endif
using ge::model_runner::ModelRunner;
using mindspore::device::ascend::ProfilingManager;
using mindspore::device::ascend::ProfilingUtils;
using mindspore::device::ascend::tasksink::TaskGenerator;
using mindspore::ge::model_runner::ModelRunner;
using mindspore::kernel::tbe::TbeUtils;
using std::vector;
@@ -158,10 +157,7 @@ void AscendKernelRuntime::ClearGraphModelMap() {
graph_kernel_events_map_.clear();
for (auto &iter : graph_model_map_) {
MS_LOG(INFO) << "Ge UnloadModel " << iter.first;
auto ret = ModelRunner::Instance().UnloadModel(iter.first);
if (!ret) {
MS_LOG(ERROR) << "UnloadModel failed";
}
ModelRunner::Instance().UnloadModel(iter.first);
}
}
@@ -194,10 +190,7 @@ void AscendKernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id, const std
MS_LOG(DEBUG) << "Clear graph:" << graph_id << " runtime resource";
if (auto model_iter = graph_model_map_.find(graph_id); model_iter != graph_model_map_.end()) {
MS_LOG(DEBUG) << "Ge UnloadModel " << graph_id;
auto ret = ModelRunner::Instance().UnloadModel(graph_id);
if (!ret) {
MS_LOG(ERROR) << "UnloadModel failed";
}
ModelRunner::Instance().UnloadModel(graph_id);
graph_model_map_.erase(model_iter);
} else {
MS_LOG(DEBUG) << "GraphId:" << graph_id << " not found";
@@ -482,10 +475,9 @@ bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) {
<< ", total label num:" << graph->label_num()
<< ", wait_active_stream_list size:" << wait_active_stream_list.size()
<< ", force_copy_stream_list size:" << force_copy_stream_list.size();
std::vector<std::shared_ptr<ge::model_runner::OpInfo>> empty_list;
auto model = std::make_shared<ge::model_runner::DavinciModel>(
task_info_list, empty_list, empty_list, empty_list, empty_list, wait_active_stream_list, force_copy_stream_list, 0,
0, 0, 0, 0, 0, resource_manager.get_cur_stream_num(), graph->label_num(), resource_manager.get_cur_event_num(), 0);
task_info_list, wait_active_stream_list, force_copy_stream_list, 0, 0, 0, 0, 0, 0,
resource_manager.get_cur_stream_num(), graph->label_num(), resource_manager.get_cur_event_num(), 0);
auto ret = graph_model_map_.insert(std::make_pair(graph->graph_id(), model));
if (!ret.second) {
MS_LOG(EXCEPTION) << "Duplicate GraphId! Please check in ascend_session.";
@@ -514,24 +506,20 @@ bool AscendKernelRuntime::LoadTask(const session::KernelGraph *graph) {
return false;
}
std::shared_ptr<ge::ModelListener> listener;
MS_LOG(INFO) << "LoadDavinciModel mode_id:" << model_iter->first;
bool status =
ModelRunner::Instance().LoadDavinciModel(device_id_, 0, model_iter->first, model_iter->second, listener);
if (!status) {
MS_LOG(EXCEPTION) << "Load Model Failed";
}
ModelRunner::Instance().LoadDavinciModel(device_id_, 0, model_iter->first, model_iter->second);
std::function<void *()> model_handle =
std::bind(&ModelRunner::GetModelHandle, &ModelRunner::Instance(), model_iter->first);
DistributeDebugTask(NOT_NULL(graph), NOT_NULL(model_handle));
status = ModelRunner::Instance().DistributeTask(model_iter->first);
if (!status) {
try {
ModelRunner::Instance().DistributeTask(model_iter->first);
} catch (const std::exception &e) {
#ifdef ENABLE_DUMP_IR
mindspore::RDR::TriggerAll();
#endif
MS_LOG(EXCEPTION) << "Distribute Task Failed";
MS_LOG(EXCEPTION) << "Distribute Task Failed, error: " << e.what();
}
if (ProfilingManager::GetInstance().IsProfiling()) {
@@ -542,10 +530,7 @@ bool AscendKernelRuntime::LoadTask(const session::KernelGraph *graph) {
LaunchDataDump(graph->graph_id());
if (!ModelRunner::Instance().LoadModelComplete(model_iter->first)) {
MS_LOG(ERROR) << "Call ge runtime LoadModelComplete failed";
return false;
}
ModelRunner::Instance().LoadModelComplete(model_iter->first);
return true;
}
@@ -730,8 +715,6 @@ bool AscendKernelRuntime::RunTask(const session::KernelGraph *graph) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
ge::InputData input_tensors = ge::InputData();
ge::OutputData *output_tensors = nullptr;
if (GraphWithEmptyTaskList(graph)) {
MS_LOG(WARNING) << "RunTask end, no task info found";
return true;
@@ -742,8 +725,9 @@ bool AscendKernelRuntime::RunTask(const session::KernelGraph *graph) {
return false;
}
bool status = ModelRunner::Instance().RunModel(graph->graph_id(), input_tensors, output_tensors);
if (!status) {
try {
ModelRunner::Instance().RunModel(graph->graph_id());
} catch (const std::exception &) {
DumpTaskExceptionInfo(graph);
std::string file_name = "task_error_debug" + std::to_string(graph->graph_id()) + ".ir";
auto graph_tmp = std::make_shared<session::KernelGraph>(*graph);
@@ -988,7 +972,7 @@ void AscendKernelRuntime::KernelLaunchProfiling(const std::string &kernel_name)
}
uint64_t AscendKernelRuntime::GetAvailableMemMaxSize() const {
auto ascend_mem_manager = dynamic_pointer_cast<AscendMemoryManager>(mem_manager_);
auto ascend_mem_manager = std::dynamic_pointer_cast<AscendMemoryManager>(mem_manager_);
return ascend_mem_manager->GetDeviceMemSize();
}


@@ -25,15 +25,15 @@
#include <unordered_set>
#include "runtime/device/kernel_runtime.h"
#include "runtime/context.h"
#include "framework/ge_runtime/davinci_model.h"
#include "runtime/device/ascend/ge_runtime/davinci_model.h"
#include "runtime/device/kernel_runtime_manager.h"
#include "backend/session/session_basic.h"
#include "runtime/device/ascend/dump/data_dumper.h"
using ge::model_runner::TaskInfo;
using std::unordered_map;
using std::vector;
namespace mindspore::device::ascend {
using ge::model_runner::TaskInfo;
class AscendKernelRuntime : public KernelRuntime {
public:
AscendKernelRuntime() = default;


@@ -16,6 +16,7 @@
#include <algorithm>
#include "runtime/device/ascend/ascend_memory_pool.h"
#include "runtime/mem.h"
#include "runtime/device/ascend/ascend_kernel_runtime.h"
#include "utils/log_adapter.h"


@@ -0,0 +1,92 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_DAVINCI_MODEL_H_
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_DAVINCI_MODEL_H_
#include <memory>
#include <vector>
#include "runtime/device/ascend/ge_runtime/task_info.h"
namespace mindspore::ge::model_runner {
class DavinciModel {
public:
DavinciModel(const std::vector<std::shared_ptr<TaskInfo>> &task_info_list,
const std::vector<uint32_t> &wait_active_stream_list,
const std::vector<uint32_t> &force_copy_stream_list, uint64_t mem_size = 0, uint64_t weight_size = 0,
uint64_t var_size = 0, uintptr_t logic_mem_base = 0, uintptr_t logic_weight_base = 0,
uintptr_t logic_var_base = 0, uint32_t stream_num = 0, uint32_t batch_num = 0, uint32_t event_num = 0,
int32_t priority = 0)
: task_info_list_(task_info_list),
wait_active_stream_list_(wait_active_stream_list),
force_copy_stream_list_(force_copy_stream_list),
mem_size_(mem_size),
weight_size_(weight_size),
var_size_(var_size),
logic_mem_base_(logic_mem_base),
logic_weight_base_(logic_weight_base),
logic_var_base_(logic_var_base),
stream_num_(stream_num),
batch_num_(batch_num),
event_num_(event_num),
priority_(priority) {}
~DavinciModel() {}
uint64_t GetMemSize() const { return mem_size_; }
uint64_t GetWeightSize() const { return weight_size_; }
uint64_t GetVarSize() const { return var_size_; }
uintptr_t GetLogicMemBase() const { return logic_mem_base_; }
uintptr_t GetLogicWeightBase() const { return logic_weight_base_; }
uintptr_t GetLogicVarBase() const { return logic_var_base_; }
uint32_t GetStreamNum() const { return stream_num_; }
uint32_t GetBatchNum() const { return batch_num_; }
uint32_t GetEventNum() const { return event_num_; }
const std::vector<uint32_t> &GetWaitActiveStreams() const { return wait_active_stream_list_; }
const std::vector<uint32_t> &GetForceCopyStreams() const { return force_copy_stream_list_; }
int32_t GetPriority() const { return priority_; }
const std::vector<std::shared_ptr<TaskInfo>> &GetTaskInfoList() const { return task_info_list_; }
private:
std::vector<std::shared_ptr<TaskInfo>> task_info_list_;
std::vector<uint32_t> wait_active_stream_list_;
std::vector<uint32_t> force_copy_stream_list_;
uint64_t mem_size_;
uint64_t weight_size_;
uint64_t var_size_;
uintptr_t logic_mem_base_;
uintptr_t logic_weight_base_;
uintptr_t logic_var_base_;
uint32_t stream_num_;
uint32_t batch_num_;
uint32_t event_num_;
int32_t priority_;
// Disable copy constructor and assignment operator
DavinciModel &operator=(const DavinciModel &) = delete;
DavinciModel(const DavinciModel &) = delete;
};
} // namespace mindspore::ge::model_runner
#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_DAVINCI_MODEL_H_
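
For reference, a minimal sketch of how this change builds a DavinciModel, mirroring the AscendKernelRuntime::GenTask call earlier in this diff. The wrapper function is hypothetical; the zeroed sizes and bases match the arguments that call path leaves unused.

#include <memory>
#include <vector>
#include "runtime/device/ascend/ge_runtime/davinci_model.h"

// Hypothetical helper for illustration only. Packages the constructor call the
// way AscendKernelRuntime::GenTask does; memory sizes and logic bases are 0
// because the internal runtime does not use them on this path.
std::shared_ptr<mindspore::ge::model_runner::DavinciModel> MakeDavinciModel(
    const std::vector<std::shared_ptr<mindspore::ge::model_runner::TaskInfo>> &tasks,
    const std::vector<uint32_t> &wait_active_streams, const std::vector<uint32_t> &force_copy_streams,
    uint32_t stream_num, uint32_t label_num, uint32_t event_num) {
  // label_num lands in the batch_num slot, which RuntimeModel::InitLabel reads as the label count.
  return std::make_shared<mindspore::ge::model_runner::DavinciModel>(
      tasks, wait_active_streams, force_copy_streams, 0, 0, 0, 0, 0, 0, stream_num, label_num, event_num, 0);
}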


@@ -0,0 +1,59 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_MODEL_CONTEXT_H_
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_MODEL_CONTEXT_H_
#include <vector>
#include "runtime/rt_model.h"
namespace mindspore::ge::model_runner {
class ModelContext {
public:
ModelContext(uint32_t device_id, uint64_t session_id, int32_t priority, rtModel_t rt_model_handle,
rtStream_t rt_model_stream, const std::vector<rtStream_t> &stream_list,
const std::vector<rtLabel_t> &label_list, const std::vector<rtEvent_t> &event_list)
: device_id_(device_id),
session_id_(session_id),
priority_(priority),
rt_model_handle_(rt_model_handle),
rt_model_stream_(rt_model_stream),
stream_list_(stream_list),
label_list_(label_list),
event_list_(event_list) {}
~ModelContext() {}
uint64_t device_id() const { return device_id_; }
uint64_t session_id() const { return session_id_; }
int32_t priority() const { return priority_; }
const rtModel_t &rt_model_handle() const { return rt_model_handle_; }
const rtStream_t &rt_model_stream() const { return rt_model_stream_; }
const std::vector<rtStream_t> &stream_list() const { return stream_list_; }
const std::vector<rtLabel_t> &label_list() const { return label_list_; }
const std::vector<rtEvent_t> &event_list() const { return event_list_; }
private:
uint32_t device_id_;
uint64_t session_id_;
int32_t priority_;
rtModel_t rt_model_handle_;
rtStream_t rt_model_stream_;
std::vector<rtStream_t> stream_list_;
std::vector<rtLabel_t> label_list_;
std::vector<rtEvent_t> event_list_;
};
} // namespace mindspore::ge::model_runner
#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_MODEL_CONTEXT_H_
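
A small sketch (function name hypothetical) of the stream lookup the task constructors below perform against this context: a single-stream model reuses stream 0; otherwise the task's stream_id must be in range (the real code throws instead of returning nullptr).

#include <vector>
#include "runtime/device/ascend/ge_runtime/model_context.h"

// Illustration of how tasks pick their stream from the context's stream list.
rtStream_t PickStream(const mindspore::ge::model_runner::ModelContext &context, uint32_t stream_id) {
  const auto &streams = context.stream_list();
  if (streams.size() == 1) {
    return streams[0];
  }
  return stream_id < streams.size() ? streams[stream_id] : nullptr;
}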


@@ -0,0 +1,104 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/device/ascend/ge_runtime/model_runner.h"
#include "runtime/device/ascend/ge_runtime/runtime_model.h"
#include "runtime/device/ascend/ge_runtime/davinci_model.h"
#include "mindspore/core/utils/log_adapter.h"
namespace mindspore::ge::model_runner {
ModelRunner &ModelRunner::Instance() {
static ModelRunner instance; // Guaranteed to be destroyed.
return instance;
}
void ModelRunner::LoadDavinciModel(uint32_t device_id, uint64_t session_id, uint32_t model_id,
const std::shared_ptr<DavinciModel> &davinci_model) {
std::shared_ptr<RuntimeModel> model = std::make_shared<RuntimeModel>();
model->Load(device_id, session_id, davinci_model);
runtime_models_[model_id] = model;
}
void ModelRunner::DistributeTask(uint32_t model_id) {
auto model_iter = runtime_models_.find(model_id);
if (model_iter == runtime_models_.end()) {
MS_LOG(EXCEPTION) << "Model id " << model_id << " not found.";
}
MS_EXCEPTION_IF_NULL(model_iter->second);
model_iter->second->DistributeTask();
}
void ModelRunner::LoadModelComplete(uint32_t model_id) {
auto model_iter = runtime_models_.find(model_id);
if (model_iter == runtime_models_.end()) {
MS_LOG(EXCEPTION) << "Model id " << model_id << " not found.";
}
MS_EXCEPTION_IF_NULL(model_iter->second);
model_iter->second->LoadComplete();
}
const std::vector<uint32_t> &ModelRunner::GetTaskIdList(uint32_t model_id) const {
auto model_iter = runtime_models_.find(model_id);
if (model_iter == runtime_models_.end()) {
MS_LOG(EXCEPTION) << "Model id " << model_id << " not found.";
}
MS_EXCEPTION_IF_NULL(model_iter->second);
return model_iter->second->GetTaskIdList();
}
const std::vector<uint32_t> &ModelRunner::GetStreamIdList(uint32_t model_id) const {
auto model_iter = runtime_models_.find(model_id);
if (model_iter == runtime_models_.end()) {
MS_LOG(EXCEPTION) << "Model id " << model_id << " not found.";
}
MS_EXCEPTION_IF_NULL(model_iter->second);
return model_iter->second->GetStreamIdList();
}
const std::map<std::string, std::shared_ptr<RuntimeInfo>> &ModelRunner::GetRuntimeInfoMap(uint32_t model_id) const {
auto model_iter = runtime_models_.find(model_id);
if (model_iter == runtime_models_.end()) {
MS_LOG(EXCEPTION) << "Model id " << model_id << " not found.";
}
MS_EXCEPTION_IF_NULL(model_iter->second);
return model_iter->second->GetRuntimeInfoMap();
}
void *ModelRunner::GetModelHandle(uint32_t model_id) const {
auto model_iter = runtime_models_.find(model_id);
if (model_iter == runtime_models_.end()) {
MS_LOG(EXCEPTION) << "Model id " << model_id << " not found.";
}
MS_EXCEPTION_IF_NULL(model_iter->second);
return model_iter->second->GetModelHandle();
}
void ModelRunner::UnloadModel(uint32_t model_id) {
auto iter = runtime_models_.find(model_id);
if (iter != runtime_models_.end()) {
(void)runtime_models_.erase(iter);
}
}
void ModelRunner::RunModel(uint32_t model_id) {
auto model_iter = runtime_models_.find(model_id);
if (model_iter == runtime_models_.end()) {
MS_LOG(EXCEPTION) << "Model id " << model_id << " not found.";
}
MS_EXCEPTION_IF_NULL(model_iter->second);
model_iter->second->Run();
}
} // namespace mindspore::ge::model_runner


@@ -0,0 +1,60 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_MODEL_RUNNER_H_
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_MODEL_RUNNER_H_
#include <memory>
#include <map>
#include <vector>
#include <tuple>
#include <string>
#include "runtime/device/ascend/ge_runtime/davinci_model.h"
namespace mindspore::ge::model_runner {
class RuntimeModel;
using RuntimeInfo = std::tuple<uint32_t, uint32_t, void *>;
class ModelRunner {
public:
static ModelRunner &Instance();
void LoadDavinciModel(uint32_t device_id, uint64_t session_id, uint32_t model_id,
const std::shared_ptr<DavinciModel> &davinci_model);
void DistributeTask(uint32_t model_id);
void LoadModelComplete(uint32_t model_id);
const std::vector<uint32_t> &GetTaskIdList(uint32_t model_id) const;
const std::vector<uint32_t> &GetStreamIdList(uint32_t model_id) const;
const std::map<std::string, std::shared_ptr<RuntimeInfo>> &GetRuntimeInfoMap(uint32_t model_id) const;
void *GetModelHandle(uint32_t model_id) const;
void UnloadModel(uint32_t model_id);
void RunModel(uint32_t model_id);
private:
ModelRunner() = default;
~ModelRunner() = default;
std::map<uint32_t, std::shared_ptr<RuntimeModel>> runtime_models_;
};
} // namespace mindspore::ge::model_runner
#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_MODEL_RUNNER_H_
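
A sketch of the sequence ascend_kernel_runtime.cc drives against this singleton after this change; failures now surface as MS_LOG(EXCEPTION) throws rather than bool returns. The wrapper function is hypothetical.

#include <cstdint>
#include <memory>
#include "runtime/device/ascend/ge_runtime/model_runner.h"

using mindspore::ge::model_runner::DavinciModel;
using mindspore::ge::model_runner::ModelRunner;

// Load/run/unload flow as used by AscendKernelRuntime in this diff.
void LoadAndRun(uint32_t device_id, uint32_t graph_id, const std::shared_ptr<DavinciModel> &model) {
  ModelRunner::Instance().LoadDavinciModel(device_id, /* session_id */ 0, graph_id, model);
  ModelRunner::Instance().DistributeTask(graph_id);    // throws on failure
  ModelRunner::Instance().LoadModelComplete(graph_id);
  ModelRunner::Instance().RunModel(graph_id);          // executes and synchronizes the model stream
  ModelRunner::Instance().UnloadModel(graph_id);
}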


@@ -0,0 +1,292 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/device/ascend/ge_runtime/runtime_model.h"
#include <set>
#include "runtime/kernel.h"
#include "runtime/rt_model.h"
#include "graphengine/inc/external/runtime/rt_error_codes.h"
#include "runtime/device/ascend/ge_runtime/model_context.h"
#include "runtime/device/ascend/ge_runtime/task/task.h"
#include "runtime/device/ascend/ge_runtime/task/task_factory.h"
#include "mindspore/core/utils/log_adapter.h"
namespace mindspore::ge::model_runner {
RuntimeModel::~RuntimeModel() {
MS_LOG(INFO) << "RuntimeModel destructor start.";
// Unbind rtModel from all task related streams
RtModelUnbindStream();
// Release tasks first, since HCCL tasks hold streams
task_list_.clear();
// Release all task related streams
RtStreamDestory();
// Release rtlabel resource
RtLabelDestory();
// Release rtEvent resources
RtEventDestory();
MS_LOG(INFO) << "Do RtModelDestroy";
// Release all rt_model
RtModelDestory();
}
void RuntimeModel::InitStream(const std::shared_ptr<DavinciModel> &davinci_model) {
MS_EXCEPTION_IF_NULL(davinci_model);
std::set<int64_t> wait_active_streams;
std::set<int64_t> force_copy_streams;
for (const auto &stream_id : davinci_model->GetWaitActiveStreams()) {
MS_LOG(INFO) << "Stream id " << stream_id << " is wait active stream.";
(void)wait_active_streams.insert(stream_id);
}
for (const auto &stream_id : davinci_model->GetForceCopyStreams()) {
MS_LOG(INFO) << "Stream id " << stream_id << " is force copy stream.";
(void)force_copy_streams.insert(stream_id);
}
MS_LOG(INFO) << "Total stream num " << davinci_model->GetStreamNum();
for (uint32_t i = 0; i < davinci_model->GetStreamNum(); ++i) {
rtStream_t stream = nullptr;
uint32_t flag = (force_copy_streams.find(i) != force_copy_streams.end())
? (RT_STREAM_PERSISTENT | RT_STREAM_FORCE_COPY)
: (RT_STREAM_PERSISTENT);
rtError_t rt_ret = rtStreamCreateWithFlags(&stream, davinci_model->GetPriority(), flag);
if (rt_ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Call rt api rtStreamCreate failed, ret: " << std::hex << rt_ret;
}
MS_LOG(INFO) << "rtStreamCreateWithFlags end.";
stream_list_.emplace_back(stream);
// Bind rt_model_handle_ to all task related streams
flag = (wait_active_streams.find(i) != wait_active_streams.end()) ? (static_cast<uint32_t>(RT_INVALID_FLAG))
: (static_cast<uint32_t>(RT_HEAD_STREAM));
rt_ret = rtModelBindStream(rt_model_handle_, stream, flag);
if (rt_ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Call rt api rtModelBindStream failed, ret: " << std::hex << rt_ret;
}
MS_LOG(INFO) << "stream index: " << i << ", stream: " << std::hex << stream;
}
}
void RuntimeModel::InitEvent(uint32_t event_num) {
MS_LOG(INFO) << "Event number: " << event_num;
for (uint32_t i = 0; i < event_num; ++i) {
rtEvent_t rt_event;
rtError_t rt_ret = rtEventCreate(&rt_event);
if (rt_ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Call rt api rtEventCreate failed, ret: " << std::hex << rt_ret;
}
event_list_.push_back(rt_event);
}
}
void RuntimeModel::InitLabel(const std::shared_ptr<DavinciModel> &davinci_model) {
MS_LOG(INFO) << "Label number: " << davinci_model->GetBatchNum();
label_list_.resize(davinci_model->GetBatchNum());
for (auto &task_info : davinci_model->GetTaskInfoList()) {
MS_EXCEPTION_IF_NULL(task_info);
if (task_info->type() != TaskInfoType::LABEL_SET) {
continue;
}
auto label_set_task_info = std::static_pointer_cast<LabelSetTaskInfo>(task_info);
if (label_set_task_info->stream_id() >= stream_list_.size()) {
MS_LOG(EXCEPTION) << "Invalid stream id " << label_set_task_info->stream_id() << " total stream num "
<< stream_list_.size();
}
rtLabel_t rt_label = nullptr;
rtError_t rt_ret = rtLabelCreateEx(&rt_label, stream_list_[label_set_task_info->stream_id()]);
if (rt_ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Call rt api rtLabelCreate failed, ret: " << std::hex << rt_ret;
}
label_list_[label_set_task_info->label_id()] = rt_label;
}
}
void RuntimeModel::InitResource(const std::shared_ptr<DavinciModel> &davinci_model) {
MS_LOG(INFO) << "InitResource start";
MS_EXCEPTION_IF_NULL(davinci_model);
rtError_t rt_ret = rtModelCreate(&rt_model_handle_, 0);
if (rt_ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Call rt api rtModelCreate failed, ret: " << std::hex << rt_ret;
}
// Create rtStream for rt_model_handle_
rt_ret = rtStreamCreate(&rt_model_stream_, davinci_model->GetPriority());
if (rt_ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Call rt api rtStreamCreate failed, ret: " << std::hex << rt_ret;
}
MS_LOG(INFO) << "rtStreamCreate end";
InitStream(davinci_model);
InitEvent(davinci_model->GetEventNum());
InitLabel(davinci_model);
MS_LOG(INFO) << "InitResource success";
}
void RuntimeModel::GenerateTask(uint32_t device_id, uint64_t session_id,
const std::shared_ptr<DavinciModel> &davinci_model) {
MS_LOG(INFO) << "GenerateTask start.";
MS_EXCEPTION_IF_NULL(davinci_model);
auto task_infos = davinci_model->GetTaskInfoList();
ModelContext model_context(device_id, session_id, davinci_model->GetPriority(), rt_model_handle_, rt_model_stream_,
stream_list_, label_list_, event_list_);
for (auto &task_info : task_infos) {
auto task = TaskFactory::GetInstance().Create(model_context, task_info);
task_list_.push_back(task);
}
MS_LOG(INFO) << "GenerateTask success.";
}
void RuntimeModel::LoadComplete() {
uint32_t task_id = 0;
uint32_t stream_id = 0;
auto rt_ret = rtModelGetTaskId(rt_model_handle_, &task_id, &stream_id);
if (rt_ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Call rt api rtModelGetTaskId failed, ret: " << std::hex << rt_ret;
}
task_id_list_.push_back(task_id);
stream_id_list_.push_back(stream_id);
rt_ret = rtModelLoadComplete(rt_model_handle_);
if (rt_ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Call rt api rtModelLoadComplete failed, ret: " << std::hex << rt_ret;
}
}
void RuntimeModel::Load(uint32_t device_id, uint64_t session_id, const std::shared_ptr<DavinciModel> &davinci_model) {
InitResource(davinci_model);
GenerateTask(device_id, session_id, davinci_model);
}
void RuntimeModel::DistributeTask() {
MS_LOG(INFO) << "DistributeTask start.";
for (auto &task : task_list_) {
MS_EXCEPTION_IF_NULL(task);
task->Distribute();
uint32_t task_id = 0;
uint32_t stream_id = 0;
rtError_t rt_ret = rtModelGetTaskId(rt_model_handle_, &task_id, &stream_id);
if (rt_ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Call rt api rtModelGetTaskId failed, ret: " << std::hex << rt_ret;
}
task_id_list_.push_back(task_id);
stream_id_list_.push_back(stream_id);
if (task->Args() != nullptr) {
std::shared_ptr<RuntimeInfo> runtime_tuple = std::make_shared<RuntimeInfo>(task_id, stream_id, task->Args());
auto emplace_ret = runtime_info_map_.emplace(task->task_name(), runtime_tuple);
if (!emplace_ret.second) {
MS_LOG(WARNING) << "Task name exist: " << task->task_name();
}
}
}
if (task_list_.empty()) {
MS_LOG(EXCEPTION) << "Task list is empty";
}
MS_LOG(INFO) << "DistributeTask success.";
}
void RuntimeModel::Run() {
MS_LOG(INFO) << "Davinci task run start.";
rtError_t ret = rtModelExecute(rt_model_handle_, rt_model_stream_, 0);
if (ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Call rt api rtModelLoadComplete failed, ret: " << std::hex << ret;
}
MS_LOG(INFO) << "Run rtModelExecute success, start to rtStreamSynchronize.";
ret = rtStreamSynchronize(rt_model_stream_);
if (ret != RT_ERROR_NONE) {
if (ret == ACL_ERROR_RT_END_OF_SEQUENCE) {
MS_LOG(INFO) << "Model stream ACL_ERROR_RT_END_OF_SEQUENCE signal received.";
return;
}
MS_LOG(EXCEPTION) << "Call rt api rtStreamSynchronize failed, ret: " << std::hex << ret;
}
MS_LOG(INFO) << "Davinci task run success.";
}
void RuntimeModel::RtModelUnbindStream() noexcept {
for (size_t i = 0; i < stream_list_.size(); i++) {
if (rtModelUnbindStream(rt_model_handle_, stream_list_[i]) != RT_ERROR_NONE) {
MS_LOG(ERROR) << "Unbind stream from model failed! Index: " << i;
return;
}
}
}
void RuntimeModel::RtStreamDestory() noexcept {
if (rtStreamDestroy(rt_model_stream_) != RT_ERROR_NONE) {
MS_LOG(ERROR) << "Destroy stream for rt_model failed!";
return;
}
for (size_t i = 0; i < stream_list_.size(); i++) {
if (rtStreamDestroy(stream_list_[i]) != RT_ERROR_NONE) {
MS_LOG(ERROR) << "Destroy stream failed! Index: " << i;
return;
}
}
}
void RuntimeModel::RtLabelDestory() noexcept {
for (size_t i = 0; i < label_list_.size(); i++) {
if (label_list_[i] == nullptr) {
continue;
}
if (rtLabelDestroy(label_list_[i]) != RT_ERROR_NONE) {
MS_LOG(ERROR) << "Destroy label failed! Index: " << i;
return;
}
}
}
void RuntimeModel::RtModelDestory() noexcept {
rtError_t ret = rtModelDestroy(rt_model_handle_);
if (ret != RT_ERROR_NONE) {
MS_LOG(ERROR) << "Call rt api rtModelDestroy failed, ret: " << std::hex << ret;
return;
}
}
void RuntimeModel::RtEventDestory() noexcept {
for (size_t i = 0; i < event_list_.size(); i++) {
if (rtEventDestroy(event_list_[i]) != RT_ERROR_NONE) {
MS_LOG(ERROR) << "Destroy event failed! Index: " << i;
return;
}
}
}
const std::vector<uint32_t> &RuntimeModel::GetTaskIdList() const { return task_id_list_; }
const std::vector<uint32_t> &RuntimeModel::GetStreamIdList() const { return stream_id_list_; }
} // namespace mindspore::ge::model_runner


@@ -0,0 +1,71 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_RUNTIME_MODEL_H_
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_RUNTIME_MODEL_H_
#include <map>
#include <memory>
#include <string>
#include <vector>
#include <tuple>
#include "runtime/base.h"
#include "runtime/rt_model.h"
#include "runtime/device/ascend/ge_runtime/davinci_model.h"
namespace mindspore::ge::model_runner {
using RuntimeInfo = std::tuple<uint32_t, uint32_t, void *>;
class Task;
class RuntimeModel {
public:
RuntimeModel() = default;
~RuntimeModel();
void Load(uint32_t device_id, uint64_t session_id, const std::shared_ptr<DavinciModel> &davinci_model);
void DistributeTask();
void LoadComplete();
const std::vector<uint32_t> &GetTaskIdList() const;
const std::vector<uint32_t> &GetStreamIdList() const;
const std::map<std::string, std::shared_ptr<RuntimeInfo>> &GetRuntimeInfoMap() const { return runtime_info_map_; }
rtModel_t GetModelHandle() const { return rt_model_handle_; }
void Run();
private:
void InitResource(const std::shared_ptr<DavinciModel> &davinci_model);
void GenerateTask(uint32_t device_id, uint64_t session_id, const std::shared_ptr<DavinciModel> &davinci_model);
void InitStream(const std::shared_ptr<DavinciModel> &davinci_model);
void InitEvent(uint32_t event_num);
void InitLabel(const std::shared_ptr<DavinciModel> &davinci_model);
void RtModelUnbindStream() noexcept;
void RtStreamDestory() noexcept;
void RtModelDestory() noexcept;
void RtLabelDestory() noexcept;
void RtEventDestory() noexcept;
rtModel_t rt_model_handle_{};
rtStream_t rt_model_stream_{};
std::vector<rtStream_t> stream_list_{};
std::vector<rtLabel_t> label_list_{};
std::vector<rtEvent_t> event_list_{};
std::vector<std::shared_ptr<Task>> task_list_{};
std::vector<uint32_t> task_id_list_{};
std::vector<uint32_t> stream_id_list_{};
std::map<std::string, std::shared_ptr<RuntimeInfo>> runtime_info_map_;
};
} // namespace mindspore::ge::model_runner
#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_RUNTIME_MODEL_H_
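
RuntimeInfo packs (task_id, stream_id, args) for each distributed task. A hypothetical consumer of the map, along the lines of what DistributeDebugTask hands to the data dumper:

#include <cstdint>
#include <tuple>
#include "runtime/device/ascend/ge_runtime/runtime_model.h"

// Illustration only: unpack the per-task info recorded by DistributeTask.
void VisitRuntimeInfo(const mindspore::ge::model_runner::RuntimeModel &model) {
  for (const auto &[task_name, info] : model.GetRuntimeInfoMap()) {
    const auto &[task_id, stream_id, args] = *info;  // std::tuple<uint32_t, uint32_t, void *>
    (void)task_name;
    (void)task_id;
    (void)stream_id;
    (void)args;  // args is the task's device-side argument block (Task::Args())
  }
}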


@@ -0,0 +1,168 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/device/ascend/ge_runtime/task/aicpu_task.h"
#include <vector>
#include "runtime/mem.h"
#include "runtime/kernel.h"
#include "runtime/device/ascend/ge_runtime/task/task_factory.h"
#include "aicpu/common/aicpu_task_struct.h"
namespace mindspore::ge::model_runner {
AicpuTask::AicpuTask(const ModelContext &model_context, const std::shared_ptr<AicpuTaskInfo> &task_info)
: TaskRepeater<AicpuTaskInfo>(model_context, task_info),
task_info_(task_info),
stream_(nullptr),
args_(nullptr),
ext_info_(nullptr),
input_output_addr_(nullptr) {
MS_EXCEPTION_IF_NULL(task_info_);
auto stream_list = model_context.stream_list();
if (stream_list.size() == 1) {
stream_ = stream_list[0];
} else if (stream_list.size() > task_info_->stream_id()) {
stream_ = stream_list[task_info_->stream_id()];
} else {
MS_LOG(EXCEPTION) << "Index: " << task_info_->stream_id() << " >= stream_list.size(): " << stream_list.size();
}
}
AicpuTask::~AicpuTask() {
ReleaseRtMem(&args_);
ReleaseRtMem(&ext_info_);
}
void AicpuTask::Distribute() {
MS_LOG(INFO) << "InitAicpuTask start.";
std::vector<void *> io_addrs;
io_addrs.insert(io_addrs.end(), task_info_->input_data_addrs().begin(), task_info_->input_data_addrs().end());
io_addrs.insert(io_addrs.end(), task_info_->output_data_addrs().begin(), task_info_->output_data_addrs().end());
auto io_addrs_num = static_cast<uint32_t>(io_addrs.size());
auto io_addrs_size = static_cast<uint32_t>(io_addrs_num * sizeof(void *));
constexpr uint32_t io_addr_offset = sizeof(aicpu::AicpuParamHead);
uint32_t node_def_len_offset = io_addr_offset + io_addrs_size;
uint32_t node_def_addr_offset = node_def_len_offset + sizeof(uint32_t);
uint32_t args_size = sizeof(aicpu::AicpuParamHead) + io_addrs_size +
static_cast<uint32_t>(task_info_->node_def().size()) + sizeof(uint32_t);
// Malloc device memory for args
rtError_t rt_ret = rtMalloc(&args_, args_size, RT_MEMORY_HBM);
if (rt_ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Call rt api rtMalloc failed, ret: " << std::hex << rt_ret;
}
SetAicpuParamHead(args_size, io_addrs_num);
SetInputOutputAddrs(io_addrs, io_addr_offset);
SetNodeDef(node_def_len_offset, node_def_addr_offset);
// for data dump
input_output_addr_ = reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(args_) + io_addr_offset);
auto dump_flag = task_info_->dump_flag() ? RT_KERNEL_DUMPFLAG : RT_KERNEL_DEFAULT;
MS_LOG(INFO) << "Distribute AicpuTask start, args_size = " << args_size << ", io_addrs_num =" << io_addrs_num
<< ", so_name = " << task_info_->so_name() << ", kernel_name = " << task_info_->kernel_name()
<< ", dump_flag = " << dump_flag;
rt_ret = rtCpuKernelLaunchWithFlag(reinterpret_cast<const void *>(task_info_->so_name().data()),
reinterpret_cast<const void *>(task_info_->kernel_name().data()), 1, args_,
args_size, nullptr, stream_, dump_flag);
if (rt_ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Call rt api rtCpuKernelLaunchWithFlag failed, ret: " << std::hex << rt_ret;
}
MS_LOG(INFO) << "Distribute AicpuTask end.";
}
void AicpuTask::ReleaseRtMem(void **ptr) noexcept {
if (ptr == nullptr || *ptr == nullptr) {
return;
}
rtError_t rt_ret = rtFree(*ptr);
if (rt_ret != RT_ERROR_NONE) {
return;
}
*ptr = nullptr;
}
void AicpuTask::SetAicpuParamHead(uint32_t args_size, uint32_t io_addrs_num) {
aicpu::AicpuParamHead aicpu_param_head;
aicpu_param_head.length = args_size;
aicpu_param_head.ioAddrNum = io_addrs_num;
const auto &ext_info = task_info_->ext_info();
uint32_t ext_size = ext_info.size();
if (ext_info.empty()) {
aicpu_param_head.extInfoLength = 0;
aicpu_param_head.extInfoAddr = 0;
} else {
rtError_t flag = rtMalloc(&ext_info_, ext_size, RT_MEMORY_HBM);
if (flag != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Call rt api rtMalloc failed, ret: " << std::hex << flag;
}
flag = rtMemcpy(ext_info_, ext_size, const_cast<void *>(reinterpret_cast<const void *>(ext_info.data())), ext_size,
RT_MEMCPY_HOST_TO_DEVICE);
if (flag != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Call rt api rtMemcpy failed, ret: " << std::hex << flag;
}
MS_LOG(INFO) << "ext info size: " << ext_size;
aicpu_param_head.extInfoLength = ext_size;
aicpu_param_head.extInfoAddr = reinterpret_cast<uintptr_t>(ext_info_);
}
// Memcpy AicpuParamHead
auto rt_ret = rtMemcpy(args_, sizeof(aicpu::AicpuParamHead), reinterpret_cast<void *>(&aicpu_param_head),
sizeof(aicpu::AicpuParamHead), RT_MEMCPY_HOST_TO_DEVICE);
if (rt_ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Call rt api rtMemcpy failed, ret: " << std::hex << rt_ret;
}
}
void AicpuTask::SetInputOutputAddrs(const std::vector<void *> &io_addrs, uint32_t io_addr_offset) {
// Memcpy io addrs
if (!io_addrs.empty()) {
auto rt_ret = rtMemcpy(reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(args_) + io_addr_offset),
static_cast<uint32_t>(io_addrs.size()) * sizeof(void *), io_addrs.data(),
static_cast<uint32_t>(io_addrs.size()) * sizeof(void *), RT_MEMCPY_HOST_TO_DEVICE);
if (rt_ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Call rt api rtMemcpy failed, ret: " << std::hex << rt_ret;
}
}
}
void AicpuTask::SetNodeDef(uint32_t node_def_len_offset, uint32_t node_def_addr_offset) {
// Memcpy node def length
auto size = task_info_->node_def().size();
auto rt_ret =
rtMemcpy(reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(args_) + node_def_len_offset), sizeof(uint32_t),
reinterpret_cast<const void *>(&size), sizeof(uint32_t), RT_MEMCPY_HOST_TO_DEVICE);
if (rt_ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Call rt api rtMemcpy failed, ret: " << std::hex << rt_ret;
}
// Memcpy node def data
rt_ret = rtMemcpy(reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(args_) + node_def_addr_offset),
task_info_->node_def().size(), reinterpret_cast<const void *>(task_info_->node_def().data()),
task_info_->node_def().size(), RT_MEMCPY_HOST_TO_DEVICE);
if (rt_ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Call rt api rtMemcpy failed, ret: " << std::hex << rt_ret;
}
}
REGISTER_TASK(TaskInfoType::AICPU, AicpuTask, AicpuTaskInfo);
} // namespace mindspore::ge::model_runner
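
Reconstructed from the offset arithmetic above, the device-side args_ buffer assembled by Distribute() has this layout (a sketch for orientation, not a struct defined in the source):

// [0, sizeof(aicpu::AicpuParamHead))                AicpuParamHead: length, ioAddrNum,
//                                                   extInfoLength, extInfoAddr
// [io_addr_offset, io_addr_offset + io_addrs_size)  input addrs, then output addrs (void *)
// [node_def_len_offset, + sizeof(uint32_t))         node_def byte length
// [node_def_addr_offset, + node_def().size())       serialized node_def bytes
//
// args_size = sizeof(aicpu::AicpuParamHead) + io_addrs_size + sizeof(uint32_t) + node_def().size()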


@@ -0,0 +1,51 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_AICPU_TASK_H_
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_AICPU_TASK_H_
#include <vector>
#include <memory>
#include <string>
#include "runtime/device/ascend/ge_runtime/task/task.h"
namespace mindspore::ge::model_runner {
class AicpuTask : public TaskRepeater<AicpuTaskInfo> {
public:
AicpuTask(const ModelContext &model_context, const std::shared_ptr<AicpuTaskInfo> &task_info);
~AicpuTask() override;
void Distribute() override;
void *Args() override { return input_output_addr_; }
std::string task_name() const override { return task_info_->op_name(); }
private:
static void ReleaseRtMem(void **ptr) noexcept;
void SetAicpuParamHead(uint32_t args_size, uint32_t io_addrs_num);
void SetInputOutputAddrs(const std::vector<void *> &io_addrs, uint32_t io_addr_offset);
void SetNodeDef(uint32_t node_def_len_offset, uint32_t node_def_addr_offset);
std::shared_ptr<AicpuTaskInfo> task_info_;
void *stream_;
void *args_;
void *ext_info_;
void *input_output_addr_;
};
} // namespace mindspore::ge::model_runner
#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_AICPU_TASK_H_


@@ -0,0 +1,54 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/device/ascend/ge_runtime/task/event_record_task.h"
#include "runtime/kernel.h"
#include "runtime/device/ascend/ge_runtime/task/task_factory.h"
namespace mindspore::ge::model_runner {
EventRecordTask::EventRecordTask(const ModelContext &model_context,
const std::shared_ptr<EventRecordTaskInfo> &task_info)
: TaskRepeater<EventRecordTaskInfo>(model_context, task_info),
task_info_(task_info),
stream_(nullptr),
event_(nullptr) {
MS_EXCEPTION_IF_NULL(task_info_);
auto stream_list = model_context.stream_list();
auto event_list = model_context.event_list();
uint32_t stream_id = task_info_->stream_id();
uint32_t event_id = task_info_->event_id();
if (stream_id >= stream_list.size() || event_id >= event_list.size()) {
MS_LOG(EXCEPTION) << "stream_list size: " << stream_list.size() << ", stream_id: " << stream_id
<< ", event_list size: " << event_list.size() << ", event_id: " << event_id;
}
stream_ = stream_list[stream_id];
event_ = event_list[event_id];
}
EventRecordTask::~EventRecordTask() {}
void EventRecordTask::Distribute() {
MS_LOG(INFO) << "EventRecordTask Distribute start, stream: " << stream_ << ", event: " << event_
<< ", stream_id: " << task_info_->stream_id() << ", event_id: " << task_info_->event_id();
rtError_t rt_ret = rtEventRecord(event_, stream_);
if (rt_ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Call rt api rtEventRecord failed, ret: " << std::hex << rt_ret;
}
MS_LOG(INFO) << "Distribute end.";
}
REGISTER_TASK(TaskInfoType::EVENT_RECORD, EventRecordTask, EventRecordTaskInfo);
} // namespace mindspore::ge::model_runner


@@ -0,0 +1,38 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_EVENT_RECORD_TASK_H_
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_EVENT_RECORD_TASK_H_
#include <memory>
#include "runtime/device/ascend/ge_runtime/task/task.h"
namespace mindspore::ge::model_runner {
class EventRecordTask : public TaskRepeater<EventRecordTaskInfo> {
public:
EventRecordTask(const ModelContext &model_context, const std::shared_ptr<EventRecordTaskInfo> &task_info);
~EventRecordTask() override;
void Distribute() override;
private:
std::shared_ptr<EventRecordTaskInfo> task_info_;
rtStream_t stream_;
rtEvent_t event_;
};
} // namespace mindspore::ge::model_runner
#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_EVENT_RECORD_TASK_H_


@@ -0,0 +1,59 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/device/ascend/ge_runtime/task/event_wait_task.h"
#include "runtime/kernel.h"
#include "runtime/device/ascend/ge_runtime/task/task_factory.h"
namespace mindspore::ge::model_runner {
EventWaitTask::EventWaitTask(const ModelContext &model_context, const std::shared_ptr<EventWaitTaskInfo> &task_info)
: TaskRepeater<EventWaitTaskInfo>(model_context, task_info),
task_info_(task_info),
stream_(nullptr),
event_(nullptr) {
MS_EXCEPTION_IF_NULL(task_info_);
auto stream_list = model_context.stream_list();
auto event_list = model_context.event_list();
uint32_t stream_id = task_info_->stream_id();
uint32_t event_id = task_info_->event_id();
if (stream_id >= stream_list.size() || event_id >= event_list.size()) {
MS_LOG(EXCEPTION) << "stream_list size: " << stream_list.size() << ", stream_id: " << stream_id
<< ", event_list size: " << event_list.size() << ", event_id: " << event_id;
}
stream_ = stream_list[stream_id];
event_ = event_list[event_id];
}
EventWaitTask::~EventWaitTask() {}
void EventWaitTask::Distribute() {
MS_LOG(INFO) << "EventWaitTask Distribute start, stream: " << stream_ << ", event: " << event_
<< ", stream_id: " << task_info_->stream_id() << ", event_id: " << task_info_->event_id();
rtError_t rt_ret = rtStreamWaitEvent(stream_, event_);
if (rt_ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Call rt api rtStreamWaitEvent failed, ret: " << std::hex << rt_ret;
}
rt_ret = rtEventReset(event_, stream_);
if (rt_ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Call rt api rtEventReset failed, ret: " << std::hex << rt_ret;
}
MS_LOG(INFO) << "Distribute end.";
}
REGISTER_TASK(TaskInfoType::EVENT_WAIT, EventWaitTask, EventWaitTaskInfo);
} // namespace mindspore::ge::model_runner
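
task_factory.h is not part of this diff, so the registration mechanism behind REGISTER_TASK is not shown. The visible contract is TaskFactory::GetInstance().Create(model_context, task_info) plus one REGISTER_TASK(type, TaskClass, TaskInfoClass) per task file; a hedged sketch of the usual pattern follows, with every name in it an assumption.

#include <functional>
#include <map>
#include <memory>

namespace sketch {
enum class TaskInfoType { EVENT_RECORD, EVENT_WAIT, AICPU };

struct TaskInfo {
  explicit TaskInfo(TaskInfoType t) : type_(t) {}
  TaskInfoType type() const { return type_; }
 private:
  TaskInfoType type_;
};
struct ModelContext {};
struct Task {
  virtual ~Task() = default;
  virtual void Distribute() = 0;
};

class TaskFactory {
 public:
  using Creator = std::function<std::shared_ptr<Task>(const ModelContext &, const std::shared_ptr<TaskInfo> &)>;
  static TaskFactory &GetInstance() {
    static TaskFactory instance;
    return instance;
  }
  void Register(TaskInfoType type, Creator creator) { creators_[type] = std::move(creator); }
  std::shared_ptr<Task> Create(const ModelContext &ctx, const std::shared_ptr<TaskInfo> &info) const {
    auto iter = creators_.find(info->type());
    return iter == creators_.end() ? nullptr : iter->second(ctx, info);
  }
 private:
  std::map<TaskInfoType, Creator> creators_;
};

// A REGISTER_TASK(type, TaskClass, InfoClass) expansion plausibly defines a
// file-scope object like this, whose constructor installs the creator.
struct TaskRegisterer {
  TaskRegisterer(TaskInfoType type, TaskFactory::Creator creator) {
    TaskFactory::GetInstance().Register(type, std::move(creator));
  }
};
}  // namespace sketch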


@@ -0,0 +1,38 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_EVENT_WAIT_TASK_H_
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_EVENT_WAIT_TASK_H_
#include <memory>
#include "runtime/device/ascend/ge_runtime/task/task.h"
namespace mindspore::ge::model_runner {
class EventWaitTask : public TaskRepeater<EventWaitTaskInfo> {
public:
EventWaitTask(const ModelContext &model_context, const std::shared_ptr<EventWaitTaskInfo> &task_info);
~EventWaitTask() override;
void Distribute() override;
private:
std::shared_ptr<EventWaitTaskInfo> task_info_;
rtStream_t stream_;
rtEvent_t event_;
};
} // namespace mindspore::ge::model_runner
#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_EVENT_WAIT_TASK_H_


@@ -0,0 +1,221 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/device/ascend/ge_runtime/task/hccl_task.h"
#include <algorithm>
#include "runtime/device/ascend/ge_runtime/task/task_factory.h"
#include "common/opskernel/ops_kernel_info_store.h"
#include "common/opskernel/ge_task_info.h"
namespace mindspore::ge::model_runner {
std::map<rtModel_t, std::map<uint32_t, std::vector<std::weak_ptr<HcclTask::StreamGuard>>>>
HcclTask::model_stream_mapping_;
std::mutex HcclTask::model_stream_mapping_mutex_;
HcclTask::HcclTask(const ModelContext &model_context, const std::shared_ptr<HcclTaskInfo> &task_info)
: TaskRepeater<HcclTaskInfo>(model_context, task_info),
task_info_(task_info),
stream_(nullptr),
workspace_mem_(nullptr),
rt_model_handle_(nullptr),
priority_(0),
secondary_stream_list_() {
MS_EXCEPTION_IF_NULL(task_info_);
priority_ = model_context.priority();
rt_model_handle_ = model_context.rt_model_handle();
auto stream_list = model_context.stream_list();
if (stream_list.size() == 1) {
stream_ = stream_list[0];
} else if (stream_list.size() > task_info_->stream_id()) {
stream_ = stream_list[task_info_->stream_id()];
} else {
MS_LOG(EXCEPTION) << "Index: " << task_info_->stream_id() << " >= stream_list.size(): " << stream_list.size();
}
}
HcclTask::~HcclTask() {}
void HcclTask::Distribute() {
// Ops kernel info store
// Get privateDef and opsKernelStorePtr
MS_LOG(INFO) << "Distribute hccl task start.";
void *ops_kernel_store = task_info_->ops_kernel_store();
::ge::OpsKernelInfoStore *ops_kernel_info_store = reinterpret_cast<::ge::OpsKernelInfoStore *>(ops_kernel_store);
MS_EXCEPTION_IF_NULL(ops_kernel_info_store);
char *private_def = reinterpret_cast<char *>(const_cast<unsigned char *>(task_info_->private_def().data()));
auto private_def_len = static_cast<uint32_t>(task_info_->private_def().size());
MS_LOG(INFO) << "The first address of the custom info, privateDef = " << static_cast<void *>(private_def);
SetSecondaryStream();
if (task_info_->workspace_size() > 0) {
workspace_mem_ = task_info_->workspace_addr();
}
::ge::GETaskInfo ge_task;
ge_task.id = 0;
ge_task.type = static_cast<uint16_t>(RT_MODEL_TASK_HCCL);
ge_task.stream = stream_;
ge_task.kernelHcclInfo = std::vector<::ge::GETaskKernelHcclInfo>(1);
ge_task.kernelHcclInfo[0].hccl_type = task_info_->hccl_type();
ge_task.kernelHcclInfo[0].inputDataAddr = task_info_->input_data_addr();
ge_task.kernelHcclInfo[0].outputDataAddr = task_info_->output_data_addr();
ge_task.kernelHcclInfo[0].workSpaceAddr = workspace_mem_;
ge_task.kernelHcclInfo[0].workSpaceMemSize = task_info_->workspace_size();
ge_task.kernelHcclInfo[0].count = task_info_->count();
ge_task.kernelHcclInfo[0].dataType = static_cast<int32_t>(task_info_->data_type());
ge_task.kernelHcclInfo[0].opType = static_cast<int32_t>(task_info_->op_type());
ge_task.kernelHcclInfo[0].rootId = task_info_->root_id();
std::vector<rtStream_t> secondary_stream_list;
std::transform(secondary_stream_list_.begin(), secondary_stream_list_.end(),
std::back_inserter(secondary_stream_list),
[](const std::shared_ptr<StreamGuard> &stream) -> rtStream_t { return stream->GetStream(); });
ge_task.kernelHcclInfo[0].hcclStreamList = secondary_stream_list;
ge_task.privateDef = private_def;
ge_task.privateDefLen = private_def_len;
ge_task.opsKernelStorePtr = ops_kernel_store;
MS_LOG(INFO) << "Begin to call function LoadTask in hccl.";
auto result = ops_kernel_info_store->LoadTask(ge_task);
// tagHcclResult::HCCL_SUCCESS is 0
if (result != 0) {
MS_LOG(EXCEPTION) << "davinci_model : load task fail, return ret: " << result;
}
MS_LOG(INFO) << "Call function LoadTask end.";
}
void HcclTask::SetSecondaryStream() {
const uint32_t master_stream_id = task_info_->stream_id();
const int64_t hccl_secondary_stream_num = task_info_->hccl_stream_num();
std::lock_guard<std::mutex> lock(model_stream_mapping_mutex_);
// No secondary streams recorded for this model yet, create them all
auto model_iter = model_stream_mapping_.find(rt_model_handle_);
if (model_iter == model_stream_mapping_.end()) {
MS_LOG(INFO) << "Need to create map for rt_model_handle_: " << rt_model_handle_ << " with new mainstream "
<< master_stream_id;
CreateStream(hccl_secondary_stream_num, master_stream_id);
MS_LOG(INFO) << "Initialize hccl secondary stream success, hccl_secondary_stream_num=" << hccl_secondary_stream_num;
return;
}
// Model exists, but this master stream has no secondary streams yet, create them all
auto &master_secondary_stream_map = model_iter->second;
auto iter = master_secondary_stream_map.find(master_stream_id);
if (iter == master_secondary_stream_map.end()) {
MS_LOG(INFO) << "Need to create secondary stream for " << task_info_->op_name() << " with new mainstream "
<< master_stream_id;
CreateStream(hccl_secondary_stream_num, master_stream_id);
MS_LOG(INFO) << "Initialize hccl secondary stream success, hccl_secondary_stream_num=" << hccl_secondary_stream_num;
return;
}
// Model and master stream exist, but the cached secondary streams are too few: reuse them and create the rest
std::vector<std::weak_ptr<StreamGuard>> &secondary_stream_vec = iter->second;
if (static_cast<size_t>(hccl_secondary_stream_num) > secondary_stream_vec.size()) {
size_t created_stream_num = secondary_stream_vec.size();
auto need_to_create_num = hccl_secondary_stream_num - created_stream_num;
MS_LOG(INFO) << "Need to reuse " << secondary_stream_vec.size() << " secondary stream and create "
<< need_to_create_num << " new secondary stream.";
for (size_t i = 0; i < secondary_stream_vec.size(); ++i) {
secondary_stream_list_.push_back(GetSecondaryStream(&secondary_stream_vec, i));
}
CreateStream(need_to_create_num, master_stream_id);
MS_LOG(INFO) << "Initialize hccl secondary stream success, hccl_secondary_stream_num=" << hccl_secondary_stream_num;
return;
}
// All cached secondary streams can be reused
MS_LOG(INFO) << "Number of secondary streams " << hccl_secondary_stream_num << " is enough to be reused.";
for (int64_t i = 0; i < hccl_secondary_stream_num; ++i) {
secondary_stream_list_.push_back(GetSecondaryStream(&secondary_stream_vec, i));
}
MS_LOG(INFO) << "Initialize hccl secondary stream success, hccl_secondary_stream_num = " << hccl_secondary_stream_num;
}
void HcclTask::CreateStream(int64_t stream_num, int64_t master_stream_id) {
MS_LOG(INFO) << "Start to create " << stream_num << " hccl secondary stream.";
for (int64_t i = 0; i < stream_num; ++i) {
rtStream_t stream = nullptr;
CreateStream(rt_model_handle_, &stream);
auto shared_stream = std::make_shared<StreamGuard>(rt_model_handle_, stream);
SaveHcclSecondaryStream(master_stream_id, shared_stream);
secondary_stream_list_.push_back(shared_stream);
}
MS_LOG(INFO) << "CreateStream success.";
}
void HcclTask::CreateStream(rtModel_t model, rtStream_t *stream) const {
MS_EXCEPTION_IF_NULL(stream);
rtError_t rt_ret = rtStreamCreateWithFlags(stream, priority_, RT_STREAM_PERSISTENT | RT_STREAM_FORCE_COPY);
if (rt_ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Call rt api rtEventRecord failed, ret: " << std::hex << rt_ret;
}
// Create secondary stream, inactive by default, activated by hccl
rt_ret = rtModelBindStream(model, *stream, RT_MODEL_WAIT_ACTIVE_STREAM);
if (rt_ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Call rt api rtEventRecord failed, ret: " << std::hex << rt_ret;
}
}
void HcclTask::SaveHcclSecondaryStream(int64_t master_stream_id, const std::shared_ptr<StreamGuard> &stream) {
if (model_stream_mapping_.find(rt_model_handle_) == model_stream_mapping_.end()) {
model_stream_mapping_.emplace(rt_model_handle_, std::map<uint32_t, std::vector<std::weak_ptr<StreamGuard>>>());
}
std::map<uint32_t, std::vector<std::weak_ptr<StreamGuard>>> &master_secondary_stream_map =
model_stream_mapping_.at(rt_model_handle_);
master_secondary_stream_map[master_stream_id].emplace_back(stream);
}
HcclTask::StreamGuard::~StreamGuard() {
rtError_t rt_ret = rtModelUnbindStream(model_, stream_);
if (rt_ret != RT_ERROR_NONE) {
MS_LOG(ERROR) << "Call rt api rtModelUnbindStream failed, ret: " << std::hex << rt_ret;
return;
}
rt_ret = rtStreamDestroy(stream_);
if (rt_ret != RT_ERROR_NONE) {
MS_LOG(ERROR) << "Call rt api rtStreamDestroy failed, ret: " << std::hex << rt_ret;
return;
}
}
std::shared_ptr<HcclTask::StreamGuard> HcclTask::GetSecondaryStream(
std::vector<std::weak_ptr<StreamGuard>> *secondary_streams, size_t index) {
MS_EXCEPTION_IF_NULL(secondary_streams);
if (index >= secondary_streams->size()) {
MS_LOG(EXCEPTION) << "Invalid stream index " << index << ", secondary streams size " << secondary_streams->size();
}
auto stream = secondary_streams->at(index).lock();
if (stream == nullptr) {
rtStream_t new_stream = nullptr;
CreateStream(rt_model_handle_, &new_stream);
stream = std::make_shared<HcclTask::StreamGuard>(rt_model_handle_, new_stream);
(*secondary_streams)[index] = stream;
}
return stream;
}
REGISTER_TASK(TaskInfoType::HCCL, HcclTask, HcclTaskInfo);
} // namespace mindspore::ge::model_runner
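
A standalone sketch of the weak_ptr reuse pattern behind the secondary-stream cache above (plain C++ with a hypothetical Stream stand-in, not module code): the static map keeps only weak references, so a secondary stream is destroyed once the last owning task releases it and is transparently rebuilt on the next request, which is what GetSecondaryStream does with a fresh rtStreamCreateWithFlags call.

#include <iostream>
#include <memory>
#include <vector>

struct Stream {  // stand-in for a StreamGuard wrapping an rtStream_t
  explicit Stream(size_t id) : id(id) { std::cout << "create stream " << id << "\n"; }
  ~Stream() { std::cout << "destroy stream " << id << "\n"; }
  size_t id;
};

// Mirrors HcclTask::GetSecondaryStream: lock the cached weak_ptr, or rebuild the entry.
std::shared_ptr<Stream> GetOrRebuild(std::vector<std::weak_ptr<Stream>> *cache, size_t index) {
  auto stream = cache->at(index).lock();
  if (stream == nullptr) {
    stream = std::make_shared<Stream>(index);
    (*cache)[index] = stream;  // the cache itself stays non-owning
  }
  return stream;
}

int main() {
  std::vector<std::weak_ptr<Stream>> cache(1);
  auto first = GetOrRebuild(&cache, 0);   // created
  first.reset();                          // last owner gone, stream destroyed
  auto second = GetOrRebuild(&cache, 0);  // rebuilt on demand
  return second != nullptr ? 0 : 1;
}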

View File

@ -0,0 +1,68 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_HCCL_TASK_H_
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_HCCL_TASK_H_
#include <memory>
#include <set>
#include <map>
#include <vector>
#include <mutex>
#include "runtime/device/ascend/ge_runtime/task/task.h"
namespace mindspore::ge::model_runner {
class HcclTask : public TaskRepeater<HcclTaskInfo> {
public:
HcclTask(const ModelContext &model_context, const std::shared_ptr<HcclTaskInfo> &task_info);
~HcclTask() override;
void Distribute() override;
private:
class StreamGuard;
void SetSecondaryStream();
void CreateStream(int64_t stream_num, int64_t master_stream_id);
void CreateStream(rtModel_t model, rtStream_t *stream) const;
void SaveHcclSecondaryStream(int64_t master_stream_id, const std::shared_ptr<StreamGuard> &stream);
std::shared_ptr<StreamGuard> GetSecondaryStream(std::vector<std::weak_ptr<StreamGuard>> *secondary_streams,
size_t index);
std::shared_ptr<HcclTaskInfo> task_info_;
void *stream_;
void *workspace_mem_;
rtModel_t rt_model_handle_;
int32_t priority_;
std::vector<std::shared_ptr<StreamGuard>> secondary_stream_list_;
// map<key: model pointer, value: map<key: primary stream id, value: vector<secondary stream pointer>>>
static std::map<rtModel_t, std::map<uint32_t, std::vector<std::weak_ptr<StreamGuard>>>> model_stream_mapping_;
static std::mutex model_stream_mapping_mutex_;
};
class HcclTask::StreamGuard {
public:
StreamGuard(rtModel_t model, rtStream_t stream) : model_(model), stream_(stream) {}
~StreamGuard();
rtStream_t GetStream() const { return stream_; }
private:
rtModel_t model_;
rtStream_t stream_;
};
} // namespace mindspore::ge::model_runner
#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_HCCL_TASK_H_

View File

@ -0,0 +1,83 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/device/ascend/ge_runtime/task/label_goto_task.h"
#include "runtime/mem.h"
#include "runtime/device/ascend/ge_runtime/task/task_factory.h"
namespace mindspore::ge::model_runner {
LabelGotoTask::LabelGotoTask(const ModelContext &model_context, const std::shared_ptr<LabelGotoTaskInfo> &task_info)
: TaskRepeater<LabelGotoTaskInfo>(model_context, task_info),
task_info_(task_info),
stream_(nullptr),
index_value_(nullptr) {
MS_EXCEPTION_IF_NULL(task_info_);
auto stream_list = model_context.stream_list();
auto label_list = model_context.label_list();
rt_model_handle_ = model_context.rt_model_handle();
uint32_t stream_id = task_info_->stream_id();
label_id_ = task_info_->label_id();
MS_LOG(INFO) << "Stream list size: " << stream_list.size() << ", stream id: " << stream_id;
MS_LOG(INFO) << "Label list size: " << label_list.size() << ", label id: " << label_id_;
if (stream_id >= stream_list.size() || label_id_ >= label_list.size()) {
MS_LOG(EXCEPTION) << "Stream/Label id invalid.";
}
stream_ = stream_list[stream_id];
label_manager_ = LabelManager::GetInstance();
MS_EXCEPTION_IF_NULL(label_manager_);
label_info_ = label_manager_->GetLabelInfo(rt_model_handle_, {label_id_}, label_list);
MS_EXCEPTION_IF_NULL(label_info_);
}
LabelGotoTask::~LabelGotoTask() {
if (index_value_ != nullptr) {
rtError_t rt_ret = rtFree(index_value_);
if (rt_ret != RT_ERROR_NONE) {
MS_LOG(ERROR) << "Call rtFree index_value_ failed, ret: " << std::hex << rt_ret;
}
index_value_ = nullptr;
}
}
void LabelGotoTask::Distribute() {
MS_LOG(INFO) << "LabelGotoTask Distribute start.";
MS_EXCEPTION_IF_NULL(stream_);
MS_EXCEPTION_IF_NULL(label_info_);
if (index_value_ == nullptr) {
rtError_t rt_ret = rtMalloc(&index_value_, sizeof(uint64_t), RT_MEMORY_HBM);
if (rt_ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Call rt api rtMalloc failed, ret: " << std::hex << rt_ret;
}
uint64_t index = 0;
rt_ret = rtMemcpy(index_value_, sizeof(uint64_t), &index, sizeof(index), RT_MEMCPY_HOST_TO_DEVICE);
if (rt_ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Call rt api rtMemcpy failed, ret: " << std::hex << rt_ret;
}
}
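// The goto is realized as a one-entry switch: index_value_ always reads 0, so the branch
// to the single label in label_info is unconditional.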
void *label_info = label_info_->GetLabelInfo();
rtError_t rt_ret = rtLabelSwitchByIndex(index_value_, 1, label_info, stream_);
if (rt_ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Call rt api rtLabelSwitchByIndex failed, ret: " << std::hex << rt_ret;
}
MS_LOG(INFO) << "DistributeTask end.";
}
REGISTER_TASK(TaskInfoType::LABEL_GOTO, LabelGotoTask, LabelGotoTaskInfo);
} // namespace mindspore::ge::model_runner

View File

@ -0,0 +1,46 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_
#include <memory>
#include <vector>
#include <map>
#include <mutex>
#include "runtime/device/ascend/ge_runtime/task/task.h"
#include "runtime/device/ascend/ge_runtime/task/label_manager.h"
namespace mindspore::ge::model_runner {
class LabelGotoTask : public TaskRepeater<LabelGotoTaskInfo> {
public:
LabelGotoTask(const ModelContext &model_context, const std::shared_ptr<LabelGotoTaskInfo> &task_info);
~LabelGotoTask() override;
void Distribute() override;
private:
std::shared_ptr<LabelGotoTaskInfo> task_info_;
void *stream_;
std::shared_ptr<LabelGuard> label_info_;
void *index_value_;
uint32_t label_id_;
rtModel_t rt_model_handle_;
std::shared_ptr<LabelManager> label_manager_;
};
} // namespace mindspore::ge::model_runner
#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_

View File

@ -0,0 +1,116 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/device/ascend/ge_runtime/task/label_manager.h"
#include <algorithm>
#include <string>
#include "runtime/mem.h"
#include "runtime/rt_model.h"
#include "mindspore/core/utils/log_adapter.h"
namespace mindspore::ge::model_runner {
std::weak_ptr<LabelManager> LabelManager::instance_;
std::mutex LabelManager::instance_mutex_;
template <class T>
static std::string GetVectorString(const std::vector<T> &vec) {
std::string ret;
for (size_t i = 0; i < vec.size(); ++i) {
if (i != 0) {
ret.push_back(',');
}
ret += std::to_string(vec[i]);
}
return ret;
}
LabelGuard::~LabelGuard() {
void *label_info = GetLabelInfo();
if (label_info != nullptr) {
rtError_t rt_ret = rtFree(label_info);
if (rt_ret != RT_ERROR_NONE) {
MS_LOG(ERROR) << "rtFree label_info failed! ret: " << std::hex << rt_ret;
}
}
}
std::shared_ptr<LabelManager> LabelManager::GetInstance() {
std::lock_guard<std::mutex> lock(instance_mutex_);
auto instance = instance_.lock();
if (instance != nullptr) {
return instance;
}
instance = std::make_shared<LabelManager>();
instance_ = instance;
return instance;
}
std::shared_ptr<LabelGuard> LabelManager::GetLabelInfo(rtModel_t model, const std::vector<uint32_t> &label_ids,
const std::vector<void *> &all_label) {
std::lock_guard<std::mutex> lock(model_info_mapping_mutex_);
rtError_t rt_ret;
auto model_iter = model_info_mapping_.find(model);
if (model_iter == model_info_mapping_.end()) {
model_info_mapping_.emplace(model, std::map<std::string, std::weak_ptr<LabelGuard>>());
model_iter = model_info_mapping_.find(model);
}
std::string label_id_str = GetVectorString(label_ids);
auto &label_map = model_iter->second;
auto label_iter = label_map.find(label_id_str);
if (label_iter != label_map.end()) {
auto label_guard = label_iter->second.lock();
if (label_guard != nullptr) {
MS_LOG(INFO) << "model " << model << " find same label id " << label_id_str;
return label_guard;
}
}
MS_LOG(INFO) << "Alloc label id " << label_id_str << " for model " << model;
void *label_info = nullptr;
std::vector<void *> label_list;
bool status = true;
std::transform(label_ids.begin(), label_ids.end(), std::back_inserter(label_list),
[&all_label, &status](uint32_t idx) -> void * {
if (idx >= all_label.size()) {
MS_LOG(ERROR) << "Invalid label id " << idx << " all label list size " << all_label.size();
status = false;
return nullptr;
}
return all_label[idx];
});
if (!status) {
MS_LOG(ERROR) << "Get label info failed.";
return nullptr;
}
uint32_t label_info_size = sizeof(rtLabelDevInfo) * label_list.size();
rt_ret = rtMalloc(&label_info, label_info_size, RT_MEMORY_HBM);
if (rt_ret != RT_ERROR_NONE) {
MS_LOG(ERROR) << "Call rt api rtMalloc failed, ret: " << std::hex << rt_ret;
return nullptr;
}
rt_ret = rtLabelListCpy(label_list.data(), label_list.size(), label_info, label_info_size);
if (rt_ret != RT_ERROR_NONE) {
MS_LOG(ERROR) << "Call rt api rtLabelListCpy failed, ret: " << std::hex << rt_ret;
return nullptr;
}
auto label_guard = std::make_shared<LabelGuard>(label_info);
label_map.emplace(label_id_str, label_guard);
return label_guard;
}
} // namespace mindspore::ge::model_runner
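
The GetInstance pattern above deliberately stores a weak_ptr: the manager is owned by the tasks that hold the returned shared_ptr, dies with the last of them, and is rebuilt on demand. A self-contained miniature of that lifetime policy (standard C++ only; the Manager name is illustrative):

#include <memory>
#include <mutex>

class Manager {
 public:
  static std::shared_ptr<Manager> GetInstance() {
    std::lock_guard<std::mutex> lock(instance_mutex_);
    auto instance = instance_.lock();
    if (instance == nullptr) {
      instance = std::make_shared<Manager>();  // rebuilt after the last owner released it
      instance_ = instance;
    }
    return instance;
  }

 private:
  static std::weak_ptr<Manager> instance_;  // non-owning: does not keep the manager alive
  static std::mutex instance_mutex_;
};

std::weak_ptr<Manager> Manager::instance_;
std::mutex Manager::instance_mutex_;

int main() {
  auto a = Manager::GetInstance();
  auto b = Manager::GetInstance();
  return (a == b) ? 0 : 1;  // same instance while any owner is alive
}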

View File

@ -0,0 +1,51 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_LABEL_MANAGER_H_
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_LABEL_MANAGER_H_
#include <vector>
#include <memory>
#include <mutex>
#include <map>
#include <string>
#include "runtime/base.h"
namespace mindspore::ge::model_runner {
class LabelGuard {
public:
explicit LabelGuard(void *label_info) : label_info_(reinterpret_cast<uintptr_t>(label_info)) {}
~LabelGuard();
void *GetLabelInfo() { return reinterpret_cast<void *>(label_info_); }
private:
uintptr_t label_info_;
};
class LabelManager {
public:
static std::shared_ptr<LabelManager> GetInstance();
std::shared_ptr<LabelGuard> GetLabelInfo(rtModel_t model, const std::vector<uint32_t> &label_ids,
const std::vector<void *> &all_label);
private:
std::mutex model_info_mapping_mutex_;
std::map<rtModel_t, std::map<std::string, std::weak_ptr<LabelGuard>>> model_info_mapping_;
static std::weak_ptr<LabelManager> instance_;
static std::mutex instance_mutex_;
};
} // namespace mindspore::ge::model_runner
#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_LABEL_MANAGER_H_

View File

@ -0,0 +1,56 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/device/ascend/ge_runtime/task/label_set_task.h"
#include "runtime/device/ascend/ge_runtime/task/task_factory.h"
namespace mindspore::ge::model_runner {
LabelSetTask::LabelSetTask(const ModelContext &model_context, const std::shared_ptr<LabelSetTaskInfo> &task_info)
: TaskRepeater<LabelSetTaskInfo>(model_context, task_info),
task_info_(task_info),
stream_(nullptr),
label_(nullptr) {
MS_EXCEPTION_IF_NULL(task_info_);
auto stream_list = model_context.stream_list();
auto label_list = model_context.label_list();
uint32_t stream_id = task_info->stream_id();
uint32_t label_id = task_info->label_id();
MS_LOG(INFO) << "Stream list size: " << stream_list.size() << ", stream id: " << stream_id;
MS_LOG(INFO) << "Label list size: " << label_list.size() << ", label id: " << label_id;
if (stream_id >= stream_list.size() || label_id >= label_list.size()) {
MS_LOG(EXCEPTION) << "Stream/Label id invalid.";
}
stream_ = stream_list[stream_id];
label_ = label_list[label_id];
}
LabelSetTask::~LabelSetTask() {}
void LabelSetTask::Distribute() {
MS_LOG(INFO) << "LabelSetTask Distribute start.";
MS_EXCEPTION_IF_NULL(stream_);
MS_EXCEPTION_IF_NULL(label_);
rtError_t rt_ret = rtLabelSet(label_, stream_);
if (rt_ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Call rt api rtLabelSet failed, ret: " << std::hex << rt_ret;
}
MS_LOG(INFO) << "DistributeTask end.";
}
REGISTER_TASK(TaskInfoType::LABEL_SET, LabelSetTask, LabelSetTaskInfo);
} // namespace mindspore::ge::model_runner

View File

@ -0,0 +1,38 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_LABEL_SET_TASK_H_
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_LABEL_SET_TASK_H_
#include <memory>
#include "runtime/device/ascend/ge_runtime/task/task.h"
namespace mindspore::ge::model_runner {
class LabelSetTask : public TaskRepeater<LabelSetTaskInfo> {
public:
LabelSetTask(const ModelContext &model_context, const std::shared_ptr<LabelSetTaskInfo> &task_info);
~LabelSetTask() override;
void Distribute() override;
private:
std::shared_ptr<LabelSetTaskInfo> task_info_;
void *stream_;
void *label_;
};
} // namespace mindspore::ge::model_runner
#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_LABEL_SET_TASK_H_

View File

@ -0,0 +1,77 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/device/ascend/ge_runtime/task/label_switch_task.h"
#include "runtime/device/ascend/ge_runtime/task/task_factory.h"
namespace mindspore::ge::model_runner {
LabelSwitchTask::LabelSwitchTask(const ModelContext &model_context,
const std::shared_ptr<LabelSwitchTaskInfo> &task_info)
: TaskRepeater<LabelSwitchTaskInfo>(model_context, task_info),
task_info_(task_info),
stream_(nullptr),
label_info_(nullptr) {
MS_EXCEPTION_IF_NULL(task_info);
rt_model_handle_ = model_context.rt_model_handle();
auto all_label_resource = model_context.label_list();
auto stream_list = model_context.stream_list();
uint32_t stream_id = task_info->stream_id();
MS_LOG(INFO) << "Stream list size: " << stream_list.size() << ", stream id: " << stream_id;
if (stream_id >= stream_list.size()) {
MS_LOG(EXCEPTION) << "Stream id invalid.";
}
stream_ = stream_list[stream_id];
label_manager_ = LabelManager::GetInstance();
MS_EXCEPTION_IF_NULL(label_manager_);
label_info_ = label_manager_->GetLabelInfo(rt_model_handle_, task_info_->label_list(), all_label_resource);
MS_EXCEPTION_IF_NULL(label_info_);
}
LabelSwitchTask::~LabelSwitchTask() {}
void LabelSwitchTask::Distribute() {
MS_LOG(INFO) << "LabelSwitchTask Distribute start.";
CheckParamValid();
void *label_info = label_info_->GetLabelInfo();
rtError_t rt_ret = rtLabelSwitchByIndex(task_info_->cond(), task_info_->label_size(), label_info, stream_);
if (rt_ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Call rt api rtLabelSwitchByIndex failed, ret: " << std::hex << rt_ret;
}
MS_LOG(INFO) << "DistributeTask end.";
}
void LabelSwitchTask::CheckParamValid() {
MS_EXCEPTION_IF_NULL(stream_);
if (task_info_->label_list().empty()) {
MS_LOG(EXCEPTION) << "label_list is empty.";
}
if (task_info_->label_size() != task_info_->label_list().size()) {
MS_LOG(EXCEPTION) << "label_list size " << task_info_->label_list().size() << " but label_size is "
<< task_info_->label_size();
}
if (task_info_->label_size() >= UINT32_MAX / sizeof(rtLabelDevInfo)) {
MS_LOG(EXCEPTION) << "label_size " << task_info_->label_size() << " will overflow.";
}
}
REGISTER_TASK(TaskInfoType::LABEL_SWITCH, LabelSwitchTask, LabelSwitchTaskInfo);
} // namespace mindspore::ge::model_runner

View File

@ -0,0 +1,43 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_
#include <memory>
#include "runtime/device/ascend/ge_runtime/task/task.h"
#include "runtime/device/ascend/ge_runtime/task/label_manager.h"
namespace mindspore::ge::model_runner {
class LabelSwitchTask : public TaskRepeater<LabelSwitchTaskInfo> {
public:
LabelSwitchTask(const ModelContext &model_context, const std::shared_ptr<LabelSwitchTaskInfo> &task_info);
~LabelSwitchTask() override;
void Distribute() override;
private:
void CheckParamValid();
std::shared_ptr<LabelSwitchTaskInfo> task_info_;
void *stream_;
rtModel_t rt_model_handle_;
std::shared_ptr<LabelGuard> label_info_;
std::shared_ptr<LabelManager> label_manager_;
};
} // namespace mindspore::ge::model_runner
#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_

View File

@ -0,0 +1,51 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/device/ascend/ge_runtime/task/memcpy_async_task.h"
#include "runtime/mem.h"
#include "runtime/device/ascend/ge_runtime/task/task_factory.h"
namespace mindspore::ge::model_runner {
MemcpyAsyncTask::MemcpyAsyncTask(const ModelContext &model_context,
const std::shared_ptr<MemcpyAsyncTaskInfo> &task_info)
: TaskRepeater<MemcpyAsyncTaskInfo>(model_context, task_info), task_info_(task_info), stream_(nullptr) {
MS_EXCEPTION_IF_NULL(task_info);
auto stream_list = model_context.stream_list();
uint32_t stream_id = task_info->stream_id();
MS_LOG(INFO) << "Stream list size: " << stream_list.size() << ", stream id: " << stream_id;
if (stream_id >= stream_list.size()) {
MS_LOG(EXCEPTION) << "Index: " << task_info->stream_id() << " >= stream_list.size(): " << stream_list.size();
}
stream_ = stream_list[stream_id];
}
MemcpyAsyncTask::~MemcpyAsyncTask() {}
void MemcpyAsyncTask::Distribute() {
MS_LOG(INFO) << "MemcpyAsyncTask Distribute start.";
MS_LOG(INFO) << "dst_max: " << task_info_->dst_max() << ", count: " << task_info_->count()
<< ", kind: " << task_info_->kind();
rtError_t rt_ret = rtMemcpyAsync(task_info_->dst(), task_info_->dst_max(), task_info_->src(), task_info_->count(),
static_cast<rtMemcpyKind_t>(task_info_->kind()), stream_);
if (rt_ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Call rt api rtMemcpyAsync failed, ret: " << std::hex << rt_ret;
}
MS_LOG(INFO) << "DistributeTask end";
}
REGISTER_TASK(TaskInfoType::MEMCPY_ASYNC, MemcpyAsyncTask, MemcpyAsyncTaskInfo);
} // namespace mindspore::ge::model_runner

View File

@ -0,0 +1,37 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_MEMCPY_ASYNC_TASK_H_
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_MEMCPY_ASYNC_TASK_H_
#include <memory>
#include "runtime/device/ascend/ge_runtime/task/task.h"
namespace mindspore::ge::model_runner {
class MemcpyAsyncTask : public TaskRepeater<MemcpyAsyncTaskInfo> {
public:
MemcpyAsyncTask(const ModelContext &model_context, const std::shared_ptr<MemcpyAsyncTaskInfo> &task_info);
~MemcpyAsyncTask() override;
void Distribute() override;
private:
std::shared_ptr<MemcpyAsyncTaskInfo> task_info_;
rtStream_t stream_;
};
} // namespace mindspore::ge::model_runner
#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_MEMCPY_ASYNC_TASK_H_

View File

@ -0,0 +1,47 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/device/ascend/ge_runtime/task/profiler_task.h"
#include "runtime/device/ascend/ge_runtime/task/task_factory.h"
namespace mindspore::ge::model_runner {
ProfilerTask::ProfilerTask(const ModelContext &model_context, const std::shared_ptr<ProfilerTraceTaskInfo> &task_info)
: TaskRepeater<ProfilerTraceTaskInfo>(model_context, task_info), task_info_(task_info), stream_(nullptr) {
MS_EXCEPTION_IF_NULL(task_info);
auto stream_list = model_context.stream_list();
uint32_t stream_id = task_info->stream_id();
MS_LOG(INFO) << "Stream list size: " << stream_list.size() << ", stream id: " << stream_id;
if (stream_id >= stream_list.size()) {
MS_LOG(EXCEPTION) << "Index: " << task_info->stream_id() << " >= stream_list.size(): " << stream_list.size();
}
stream_ = stream_list[stream_id];
}
ProfilerTask::~ProfilerTask() {}
void ProfilerTask::Distribute() {
MS_LOG(INFO) << "ProfilerTask Distribute start.";
MS_LOG(INFO) << "log id = " << task_info_->log_id() << ", notify = " << task_info_->notify()
<< ", flat = " << task_info_->flat();
rtError_t rt_ret = rtProfilerTrace(task_info_->log_id(), task_info_->notify(), task_info_->flat(), stream_);
if (rt_ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Call rt api rtProfilerTrace failed, ret: " << std::hex << rt_ret;
}
MS_LOG(INFO) << "DistributeTask end.";
}
REGISTER_TASK(TaskInfoType::PROFILER_TRACE, ProfilerTask, ProfilerTraceTaskInfo);
} // namespace mindspore::ge::model_runner

View File

@ -0,0 +1,37 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_PROFILER_TASK_H_
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_PROFILER_TASK_H_
#include <memory>
#include "runtime/device/ascend/ge_runtime/task/task.h"
namespace mindspore::ge::model_runner {
class ProfilerTask : public TaskRepeater<ProfilerTraceTaskInfo> {
public:
ProfilerTask(const ModelContext &model_context, const std::shared_ptr<ProfilerTraceTaskInfo> &task_info);
~ProfilerTask() override;
void Distribute() override;
private:
std::shared_ptr<ProfilerTraceTaskInfo> task_info_;
rtStream_t stream_;
};
} // namespace mindspore::ge::model_runner
#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_PROFILER_TASK_H_

View File

@ -0,0 +1,56 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/device/ascend/ge_runtime/task/stream_active_task.h"
#include "runtime/kernel.h"
#include "runtime/device/ascend/ge_runtime/task/task_factory.h"
namespace mindspore::ge::model_runner {
StreamActiveTask::StreamActiveTask(const ModelContext &model_context,
const std::shared_ptr<StreamActiveTaskInfo> &task_info)
: TaskRepeater<StreamActiveTaskInfo>(model_context, task_info),
task_info_(task_info),
stream_(nullptr),
active_stream_(nullptr) {
MS_EXCEPTION_IF_NULL(task_info);
auto stream_list = model_context.stream_list();
uint32_t stream_id = task_info->stream_id();
uint32_t active_stream_id = task_info->active_stream_id();
MS_LOG(INFO) << "Stream list size: " << stream_list.size() << ", stream id: " << stream_id
<< ", active stream id: " << active_stream_id;
if (stream_id >= stream_list.size() || active_stream_id >= stream_list.size()) {
MS_LOG(EXCEPTION) << "Stream id invalid";
}
stream_ = stream_list[stream_id];
active_stream_ = stream_list[active_stream_id];
}
StreamActiveTask::~StreamActiveTask() {}
void StreamActiveTask::Distribute() {
MS_LOG(INFO) << "Distribute start";
MS_LOG(INFO) << "Stream " << task_info_->stream_id() << " active " << task_info_->active_stream_id();
MS_EXCEPTION_IF_NULL(stream_);
MS_EXCEPTION_IF_NULL(active_stream_);
rtError_t rt_ret = rtStreamActive(active_stream_, stream_);
if (rt_ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Call rt api rtStreamActive failed, ret: " << std::hex << rt_ret;
}
MS_LOG(INFO) << "DistributeTask end.";
}
REGISTER_TASK(TaskInfoType::STREAM_ACTIVE, StreamActiveTask, StreamActiveTaskInfo);
} // namespace mindspore::ge::model_runner

View File

@ -0,0 +1,38 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_STREAM_ACTIVE_TASK_H_
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_STREAM_ACTIVE_TASK_H_
#include <memory>
#include "runtime/device/ascend/ge_runtime/task/task.h"
namespace mindspore::ge::model_runner {
class StreamActiveTask : public TaskRepeater<StreamActiveTaskInfo> {
public:
StreamActiveTask(const ModelContext &model_context, const std::shared_ptr<StreamActiveTaskInfo> &task_info);
~StreamActiveTask() override;
void Distribute() override;
private:
std::shared_ptr<StreamActiveTaskInfo> task_info_;
rtStream_t stream_;
rtStream_t active_stream_;
};
} // namespace mindspore::ge::model_runner
#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_STREAM_ACTIVE_TASK_H_

View File

@ -0,0 +1,70 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/device/ascend/ge_runtime/task/stream_switch_task.h"
#include "runtime/kernel.h"
#include "runtime/device/ascend/ge_runtime/task/task_factory.h"
namespace mindspore::ge::model_runner {
StreamSwitchTask::StreamSwitchTask(const ModelContext &model_context,
const std::shared_ptr<StreamSwitchTaskInfo> &task_info)
: TaskRepeater<StreamSwitchTaskInfo>(model_context, task_info),
task_info_(task_info),
stream_(nullptr),
stream_list_() {
MS_EXCEPTION_IF_NULL(task_info);
stream_list_ = model_context.stream_list();
if (stream_list_.size() == 1) {
stream_ = stream_list_[0];
} else if (stream_list_.size() > task_info->stream_id()) {
stream_ = stream_list_[task_info->stream_id()];
} else {
MS_LOG(EXCEPTION) << "Index: " << task_info->stream_id() << " >= stream_list.size(): " << stream_list_.size();
}
}
StreamSwitchTask::~StreamSwitchTask() {}
void StreamSwitchTask::Distribute() {
MS_LOG(INFO) << "Init StreamSwitchTask start.";
MS_LOG(INFO) << "Stream " << task_info_->stream_id() << " active " << task_info_->true_stream_id();
MS_EXCEPTION_IF_NULL(stream_);
if (static_cast<uint64_t>(task_info_->true_stream_id()) >= stream_list_.size()) {
MS_LOG(EXCEPTION) << "true_stream_id " << task_info_->true_stream_id() << " must be less than stream_list_ size "
<< stream_list_.size();
}
void *input = reinterpret_cast<void *>(task_info_->input_addr());
rtCondition_t cond = static_cast<rtCondition_t>(task_info_->cond());
void *value = reinterpret_cast<void *>(task_info_->value_addr());
rtStream_t true_stream = stream_list_[task_info_->true_stream_id()];
rtSwitchDataType_t data_type = static_cast<rtSwitchDataType_t>(task_info_->data_type());
MS_LOG(INFO) << "InitStreamSwitchTask, cond: " << cond << ", trueStream: " << true_stream
<< ", trueStreamID: " << task_info_->true_stream_id() << ", datatype: " << task_info_->data_type();
MS_LOG(INFO) << "StreamSwitchTask Distribute Start.";
rtError_t rt_ret = rtStreamSwitchEx(input, cond, value, true_stream, stream_, data_type);
if (rt_ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Call rt api rtStreamSwitchEx failed, ret: " << std::hex << rt_ret;
}
MS_LOG(INFO) << "Distribute StreamSwitch success";
}
REGISTER_TASK(TaskInfoType::STREAM_SWITCH, StreamSwitchTask, StreamSwitchTaskInfo);
} // namespace mindspore::ge::model_runner

View File

@ -0,0 +1,40 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_STREAM_SWITCH_TASK_H_
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_STREAM_SWITCH_TASK_H_
#include <memory>
#include <vector>
#include "runtime/device/ascend/ge_runtime/task/task.h"
namespace mindspore::ge::model_runner {
class StreamSwitchTask : public TaskRepeater<StreamSwitchTaskInfo> {
public:
StreamSwitchTask(const ModelContext &model_context, const std::shared_ptr<StreamSwitchTaskInfo> &task_info);
~StreamSwitchTask() override;
void Distribute() override;
private:
std::shared_ptr<StreamSwitchTaskInfo> task_info_;
void *stream_;
std::vector<rtStream_t> stream_list_;
};
} // namespace mindspore::ge::model_runner
#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_STREAM_SWITCH_TASK_H_

View File

@ -0,0 +1,53 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_TASK_H_
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_TASK_H_
#include <memory>
#include <utility>
#include <vector>
#include <string>
#include "runtime/device/ascend/ge_runtime/model_context.h"
#include "runtime/device/ascend/ge_runtime/task_info.h"
namespace mindspore::ge::model_runner {
class Task {
public:
Task() {}
virtual ~Task() {}
virtual void Distribute() = 0;
virtual void *Args() { return nullptr; }
virtual std::string task_name() const { return ""; }
};
template <class T>
class TaskRepeater : public Task {
static_assert(std::is_base_of<TaskInfo, T>(), "Wrong TaskInfo Type!");
public:
TaskRepeater(const ModelContext &model_context, const std::shared_ptr<T> &task_info) {}
virtual ~TaskRepeater() {}
virtual void Distribute() = 0;
};
} // namespace mindspore::ge::model_runner
#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_TASK_H_
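
A self-contained miniature of the TaskRepeater contract above (standard C++ only; the Noop names are illustrative): a concrete task names its TaskInfo specialization, the static_assert rejects any parameter that is not derived from TaskInfo, and Distribute stays pure virtual until the leaf class provides it.

#include <iostream>
#include <memory>
#include <type_traits>

struct TaskInfo { virtual ~TaskInfo() = default; };
struct NoopTaskInfo : TaskInfo {};
struct ModelContext {};

class Task {
 public:
  virtual ~Task() = default;
  virtual void Distribute() = 0;
};

template <class T>
class TaskRepeater : public Task {
  static_assert(std::is_base_of<TaskInfo, T>::value, "Wrong TaskInfo Type!");

 public:
  TaskRepeater(const ModelContext &, const std::shared_ptr<T> &) {}
};

class NoopTask : public TaskRepeater<NoopTaskInfo> {
 public:
  using TaskRepeater::TaskRepeater;
  void Distribute() override { std::cout << "distribute\n"; }
};

int main() {
  ModelContext context;
  NoopTask task(context, std::make_shared<NoopTaskInfo>());
  task.Distribute();
  return 0;
}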

View File

@ -0,0 +1,84 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_TASK_FACTORY_H_
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_TASK_FACTORY_H_
#include <functional>
#include <map>
#include <memory>
#include <unordered_map>
#include "runtime/device/ascend/ge_runtime/task_info.h"
#include "mindspore/core/utils/log_adapter.h"
namespace mindspore::ge::model_runner {
class Task;
class ModelContext;
using TASK_CREATOR_FUN = std::function<std::shared_ptr<Task>(const ModelContext &, std::shared_ptr<TaskInfo>)>;
class TaskFactory {
private:
TaskFactory() {}
~TaskFactory() {}
void RegisterCreator(const TaskInfoType &type, const TASK_CREATOR_FUN &func) {
if (creator_map_.find(type) != creator_map_.end()) {
MS_LOG(WARNING) << "Creator type " << type << " already exist.";
}
creator_map_[type] = func;
}
std::map<TaskInfoType, TASK_CREATOR_FUN> creator_map_;
public:
static TaskFactory &GetInstance() {
static TaskFactory instance;
return instance;
}
std::shared_ptr<Task> Create(const ModelContext &model_context, const std::shared_ptr<TaskInfo> &task_info) const {
if (task_info == nullptr) {
MS_LOG(ERROR) << "task_info is null.";
return nullptr;
}
auto iter = creator_map_.find(task_info->type());
if (iter == creator_map_.end()) {
MS_LOG(ERROR) << "Unknown task type " << task_info->type();
return nullptr;
}
return iter->second(model_context, task_info);
}
class Register {
public:
Register(const TaskInfoType &type, const TASK_CREATOR_FUN &func) {
MS_LOG(DEBUG) << "register type " << type;
TaskFactory::GetInstance().RegisterCreator(type, func);
}
~Register() {}
};
};
#define REGISTER_TASK(type, task_clazz, task_info_clazz) \
TaskFactory::Register g_##task_clazz##_register( \
type, [](const ModelContext &model_context, const std::shared_ptr<TaskInfo> &task_info) -> std::shared_ptr<Task> { \
std::shared_ptr<task_info_clazz> concrete_task_info = std::static_pointer_cast<task_info_clazz>(task_info); \
return std::make_shared<task_clazz>(model_context, concrete_task_info); \
});
} // namespace mindspore::ge::model_runner
#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_TASK_FACTORY_H_
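
REGISTER_TASK is the whole wiring between task types and the factory: each task translation unit defines one file-scope Register object, and its constructor installs a creator lambda before main runs, so merely linking the .cc file makes that task type constructible through TaskFactory::Create. A self-contained miniature of this self-registration pattern (standard C++ only; Kind, Info, Registrar, and NoopTask are illustrative names):

#include <functional>
#include <iostream>
#include <map>
#include <memory>

enum class Kind { kNoop };
struct Info { Kind kind; };
struct Task {
  virtual ~Task() = default;
  virtual void Run() = 0;
};

using Creator = std::function<std::shared_ptr<Task>(const std::shared_ptr<Info> &)>;

class Factory {
 public:
  static Factory &Get() {
    static Factory instance;
    return instance;
  }
  void Add(Kind kind, Creator creator) { creators_[kind] = std::move(creator); }
  std::shared_ptr<Task> Create(const std::shared_ptr<Info> &info) const {
    auto iter = creators_.find(info->kind);
    return iter == creators_.end() ? nullptr : iter->second(info);
  }

 private:
  std::map<Kind, Creator> creators_;
};

// What REGISTER_TASK boils down to: one static registrar per task type.
struct Registrar {
  Registrar(Kind kind, Creator creator) { Factory::Get().Add(kind, std::move(creator)); }
};

struct NoopTask : Task {
  void Run() override { std::cout << "noop task\n"; }
};

static Registrar g_noop_registrar(Kind::kNoop,
                                  [](const std::shared_ptr<Info> &) { return std::make_shared<NoopTask>(); });

int main() {
  auto task = Factory::Get().Create(std::make_shared<Info>(Info{Kind::kNoop}));
  if (task != nullptr) {
    task->Run();
  }
  return 0;
}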

View File

@ -0,0 +1,97 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/device/ascend/ge_runtime/task/tbe_task.h"
#include <vector>
#include "runtime/mem.h"
#include "runtime/kernel.h"
#include "runtime/device/ascend/ge_runtime/task/task_factory.h"
namespace mindspore::ge::model_runner {
TbeTask::TbeTask(const ModelContext &model_context, const std::shared_ptr<TbeTaskInfo> &task_info)
: TaskRepeater<TbeTaskInfo>(model_context, task_info),
task_info_(task_info),
stream_(nullptr),
stub_func_(nullptr),
args_(nullptr) {
MS_EXCEPTION_IF_NULL(task_info);
auto stream_list = model_context.stream_list();
if (stream_list.size() == 1) {
stream_ = stream_list[0];
} else if (stream_list.size() > task_info->stream_id()) {
stream_ = stream_list[task_info->stream_id()];
} else {
MS_LOG(EXCEPTION) << "Index: " << task_info->stream_id() << " >= stream_list.size(): " << stream_list.size();
}
}
TbeTask::~TbeTask() {
if (args_ != nullptr) {
rtError_t rt_ret = rtFree(args_);
if (rt_ret != RT_ERROR_NONE) {
MS_LOG(ERROR) << "Call rt api rtFree failed, ret: " << std::hex << rt_ret;
}
args_ = nullptr;
}
}
void TbeTask::Distribute() {
MS_LOG(INFO) << "InitTbeTask start.";
MS_EXCEPTION_IF_NULL(stream_);
// Get stub_func
if (task_info_->stub_func().empty()) {
MS_LOG(EXCEPTION) << "kernel_info->stub_func is empty!";
}
rtError_t rt_ret = rtGetFunctionByName(const_cast<char *>(task_info_->stub_func().c_str()), &stub_func_);
if (rt_ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Call rt api rtGetFunctionByName failed, ret: " << std::hex << rt_ret;
}
MS_LOG(INFO) << "TbeTask: stub_func = " << task_info_->stub_func();
// Get args
std::vector<void *> tensor_device_addrs;
tensor_device_addrs.insert(tensor_device_addrs.end(), task_info_->input_data_addrs().begin(),
task_info_->input_data_addrs().end());
tensor_device_addrs.insert(tensor_device_addrs.end(), task_info_->output_data_addrs().begin(),
task_info_->output_data_addrs().end());
tensor_device_addrs.insert(tensor_device_addrs.end(), task_info_->workspace_addrs().begin(),
task_info_->workspace_addrs().end());
auto args_size = static_cast<uint32_t>(tensor_device_addrs.size() * sizeof(void *));
rt_ret = rtMalloc(&args_, args_size, RT_MEMORY_HBM);
if (rt_ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Call rt api rtMalloc failed, ret: " << std::hex << rt_ret << " mem size " << args_size;
}
rt_ret = rtMemcpy(args_, args_size, reinterpret_cast<void *>(tensor_device_addrs.data()), args_size,
RT_MEMCPY_HOST_TO_DEVICE);
if (rt_ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Call rt api rtMemcpy failed, ret: " << std::hex << rt_ret;
}
MS_LOG(INFO) << "DistributeTbeTask start.";
auto dump_flag = task_info_->dump_flag() ? RT_KERNEL_DUMPFLAG : RT_KERNEL_DEFAULT;
rt_ret = rtKernelLaunchWithFlag(stub_func_, task_info_->block_dim(), args_, args_size, nullptr, stream_, dump_flag);
if (rt_ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Call rt api rtKernelLaunch failed, ret: " << std::hex << rt_ret << " mem size " << args_size;
}
MS_LOG(INFO) << "[DataDump] task name: " << task_info_->op_name() << " dump_flag: " << dump_flag;
}
REGISTER_TASK(TaskInfoType::TBE, TbeTask, TbeTaskInfo);
} // namespace mindspore::ge::model_runner
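
The argument handling in TbeTask::Distribute reduces to one fixed layout: the kernel receives a single contiguous table of raw device pointers, inputs first, then outputs, then workspaces, copied to device memory in one shot. A host-only sketch of that packing (standard C++; std::memcpy stands in for rtMemcpy and the addresses are dummies):

#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>

int main() {
  // Dummy device addresses, in place of real input/output/workspace allocations.
  std::vector<void *> inputs{reinterpret_cast<void *>(0x1000)};
  std::vector<void *> outputs{reinterpret_cast<void *>(0x2000)};
  std::vector<void *> workspaces{reinterpret_cast<void *>(0x3000)};

  // Flatten in the order the kernel expects: inputs, outputs, workspaces.
  std::vector<void *> table;
  table.insert(table.end(), inputs.begin(), inputs.end());
  table.insert(table.end(), outputs.begin(), outputs.end());
  table.insert(table.end(), workspaces.begin(), workspaces.end());

  const auto args_size = static_cast<uint32_t>(table.size() * sizeof(void *));
  std::vector<char> device_args(args_size);  // stands in for the rtMalloc'd HBM buffer
  std::memcpy(device_args.data(), table.data(), args_size);

  std::cout << "packed " << table.size() << " addresses into " << args_size << " bytes\n";
  return 0;
}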

View File

@ -0,0 +1,44 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_TBE_TASK_H_
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_TBE_TASK_H_
#include <string>
#include <memory>
#include "runtime/device/ascend/ge_runtime/task/task.h"
namespace mindspore::ge::model_runner {
class TbeTask : public TaskRepeater<TbeTaskInfo> {
public:
TbeTask(const ModelContext &model_context, const std::shared_ptr<TbeTaskInfo> &task_info);
~TbeTask() override;
void Distribute() override;
void *Args() override { return args_; }
std::string task_name() const override { return task_info_->op_name(); }
private:
std::shared_ptr<TbeTaskInfo> task_info_;
void *stream_;
void *stub_func_;
void *args_;
};
} // namespace mindspore::ge::model_runner
#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_TBE_TASK_H_

View File

@ -0,0 +1,364 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_INFO_H_
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_INFO_H_
#include <stdint.h>
#include <memory>
#include <string>
#include <utility>
#include <vector>
namespace mindspore::ge::model_runner {
enum TaskInfoType {
CCE = 0,
TBE,
AICPU,
LABEL_SET,
LABEL_SWITCH,
LABEL_GOTO,
EVENT_RECORD,
EVENT_WAIT,
FUSION_START,
FUSION_END,
HCCL,
PROFILER_TRACE,
MEMCPY_ASYNC,
STREAM_SWITCH,
STREAM_ACTIVE,
// Insert new task type here
RESERVED = 23
};
class TaskInfo {
public:
virtual ~TaskInfo() {}
uint32_t stream_id() const { return stream_id_; }
TaskInfoType type() const { return type_; }
std::string op_name() const { return op_name_; }
bool dump_flag() const { return dump_flag_; }
protected:
TaskInfo(const std::string &op_name, uint32_t stream_id, TaskInfoType type, bool dump_flag)
: op_name_(op_name), stream_id_(stream_id), type_(type), dump_flag_(dump_flag) {}
private:
std::string op_name_;
uint32_t stream_id_;
TaskInfoType type_;
bool dump_flag_;
};
class TbeTaskInfo : public TaskInfo {
public:
TbeTaskInfo(const std::string &op_name, uint32_t stream_id, const std::string &stub_func, uint32_t block_dim,
const std::vector<uint8_t> &args, uint32_t args_size, const std::vector<uint8_t> &sm_desc, void *binary,
uint32_t binary_size, const std::vector<uint8_t> &meta_data, const std::vector<void *> &input_data_addrs,
const std::vector<void *> &output_data_addrs, const std::vector<void *> &workspace_addrs, bool dump_flag)
: TaskInfo(op_name, stream_id, TaskInfoType::TBE, dump_flag),
stub_func_(stub_func),
block_dim_(block_dim),
args_(args),
args_size_(args_size),
sm_desc_(sm_desc),
binary_(binary),
binary_size_(binary_size),
meta_data_(meta_data),
input_data_addrs_(input_data_addrs),
output_data_addrs_(output_data_addrs),
workspace_addrs_(workspace_addrs) {}
~TbeTaskInfo() override {}
const std::string &stub_func() const { return stub_func_; }
uint32_t block_dim() const { return block_dim_; }
const std::vector<uint8_t> &args() const { return args_; }
uint32_t args_size() const { return args_size_; }
const std::vector<uint8_t> &sm_desc() const { return sm_desc_; }
void *binary() const { return binary_; }
uint32_t binary_size() const { return binary_size_; }
const std::vector<uint8_t> &meta_data() const { return meta_data_; }
const std::vector<void *> &input_data_addrs() const { return input_data_addrs_; }
const std::vector<void *> &output_data_addrs() const { return output_data_addrs_; }
const std::vector<void *> &workspace_addrs() const { return workspace_addrs_; }
void SetBinary(void *binary, uint32_t binary_size) {
binary_ = binary;
binary_size_ = binary_size;
}
private:
std::string stub_func_;
uint32_t block_dim_;
std::vector<uint8_t> args_;
uint32_t args_size_;
std::vector<uint8_t> sm_desc_;
void *binary_;
uint32_t binary_size_;
std::vector<uint8_t> meta_data_;
std::vector<void *> input_data_addrs_;
std::vector<void *> output_data_addrs_;
std::vector<void *> workspace_addrs_;
};
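// SetBinary() is the only mutator here; judging from the API shape it lets a
// loader attach the compiled kernel binary after the descriptor is built (the
// unit test in this patch simply supplies the binary at construction time).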
class AicpuTaskInfo : public TaskInfo {
public:
AicpuTaskInfo(const std::string &op_name, uint32_t stream_id, const std::string &so_name,
const std::string &kernel_name, const std::string &node_def, const std::string &ext_info,
const std::vector<void *> &input_data_addrs, const std::vector<void *> &output_data_addrs,
bool dump_flag)
: TaskInfo(op_name, stream_id, TaskInfoType::AICPU, dump_flag),
so_name_(so_name),
kernel_name_(kernel_name),
node_def_(node_def),
ext_info_(ext_info),
input_data_addrs_(input_data_addrs),
output_data_addrs_(output_data_addrs) {}
~AicpuTaskInfo() override {}
const std::string &so_name() const { return so_name_; }
const std::string &kernel_name() const { return kernel_name_; }
const std::string &node_def() const { return node_def_; }
const std::vector<void *> &input_data_addrs() const { return input_data_addrs_; }
const std::vector<void *> &output_data_addrs() const { return output_data_addrs_; }
const std::string &ext_info() const { return ext_info_; }
private:
std::string so_name_;
std::string kernel_name_;
std::string node_def_;
std::string ext_info_;
std::vector<void *> input_data_addrs_;
std::vector<void *> output_data_addrs_;
};
class LabelSetTaskInfo : public TaskInfo {
public:
LabelSetTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_id)
: TaskInfo(op_name, stream_id, TaskInfoType::LABEL_SET, false), label_id_(label_id) {}
~LabelSetTaskInfo() override {}
uint32_t label_id() const { return label_id_; }
private:
uint32_t label_id_;
};
class LabelGotoTaskInfo : public TaskInfo {
public:
LabelGotoTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_id)
: TaskInfo(op_name, stream_id, TaskInfoType::LABEL_GOTO, false), label_id_(label_id) {}
~LabelGotoTaskInfo() override {}
uint32_t label_id() const { return label_id_; }
private:
uint32_t label_id_;
};
class LabelSwitchTaskInfo : public TaskInfo {
public:
LabelSwitchTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_size,
const std::vector<uint32_t> &label_list, void *cond)
: TaskInfo(op_name, stream_id, TaskInfoType::LABEL_SWITCH, false),
label_size_(label_size),
label_list_(label_list),
cond_(cond) {}
~LabelSwitchTaskInfo() override {}
uint32_t label_size() const { return label_size_; }
const std::vector<uint32_t> &label_list() const { return label_list_; }
void *cond() const { return cond_; }
private:
uint32_t label_size_;
std::vector<uint32_t> label_list_;
void *cond_;
};
class EventTaskInfo : public TaskInfo {
public:
uint32_t event_id() const { return event_id_; }
protected:
EventTaskInfo(const std::string &op_name, uint32_t stream_id, TaskInfoType type, uint32_t event_id)
: TaskInfo(op_name, stream_id, type, false), event_id_(event_id) {}
~EventTaskInfo() override {}
uint32_t event_id_;
};
class EventRecordTaskInfo : public EventTaskInfo {
public:
EventRecordTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t event_id)
: EventTaskInfo(op_name, stream_id, TaskInfoType::EVENT_RECORD, event_id) {}
~EventRecordTaskInfo() override {}
};
class EventWaitTaskInfo : public EventTaskInfo {
public:
EventWaitTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t event_id)
: EventTaskInfo(op_name, stream_id, TaskInfoType::EVENT_WAIT, event_id) {}
~EventWaitTaskInfo() override {}
};
class FusionStartTaskInfo : public TaskInfo {
public:
explicit FusionStartTaskInfo(const std::string &op_name, uint32_t stream_id)
: TaskInfo(op_name, stream_id, TaskInfoType::FUSION_START, false) {}
~FusionStartTaskInfo() override {}
};
class FusionEndTaskInfo : public TaskInfo {
public:
explicit FusionEndTaskInfo(const std::string &op_name, uint32_t stream_id)
: TaskInfo(op_name, stream_id, TaskInfoType::FUSION_END, false) {}
~FusionEndTaskInfo() override {}
};
class HcclTaskInfo : public TaskInfo {
public:
HcclTaskInfo(const std::string &op_name, uint32_t stream_id, const std::string &hccl_type, void *input_data_addr,
void *output_data_addr, void *workspace_addr, int64_t workspace_size, int64_t hccl_stream_num,
const std::vector<uint8_t> &private_def, void *ops_kernel_store, int32_t count, int64_t root_id,
int64_t op_type, int64_t data_type, const std::string &group, bool dump_flag)
: TaskInfo(op_name, stream_id, TaskInfoType::HCCL, dump_flag),
hccl_type_(hccl_type),
input_data_addr_(input_data_addr),
output_data_addr_(output_data_addr),
workspace_addr_(workspace_addr),
workspace_size_(workspace_size),
hccl_stream_num_(hccl_stream_num),
private_def_(private_def),
ops_kernel_store_(ops_kernel_store),
count_(count),
root_id_(root_id),
op_type_(op_type),
data_type_(data_type),
group_(group) {}
~HcclTaskInfo() override {}
const std::string &hccl_type() const { return hccl_type_; }
void *input_data_addr() const { return input_data_addr_; }
void *output_data_addr() const { return output_data_addr_; }
void *workspace_addr() const { return workspace_addr_; }
int64_t workspace_size() const { return workspace_size_; }
int64_t hccl_stream_num() const { return hccl_stream_num_; }
const std::vector<uint8_t> &private_def() const { return private_def_; }
void *ops_kernel_store() const { return ops_kernel_store_; }
int32_t count() const { return count_; }
int64_t root_id() const { return root_id_; }
int64_t op_type() const { return op_type_; }
int64_t data_type() const { return data_type_; }
const std::string &group() const { return group_; }
private:
std::string hccl_type_;
void *input_data_addr_;
void *output_data_addr_;
void *workspace_addr_;
int64_t workspace_size_;
int64_t hccl_stream_num_;
std::vector<uint8_t> private_def_;
void *ops_kernel_store_;
int32_t count_;
int64_t root_id_;
int64_t op_type_;
int64_t data_type_;
std::string group_;
};
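// hccl_stream_num is the number of secondary streams the op needs. HcclTask
// caches streams per (model, stream_id) and reuses them across tasks, so the
// cached vector only grows to the largest requested count; see
// test_hccl_task_create_stream_reuse_success in this patch.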
class ProfilerTraceTaskInfo : public TaskInfo {
public:
ProfilerTraceTaskInfo(const std::string &op_name, uint32_t stream_id, uint64_t log_id, bool notify, uint32_t flat)
: TaskInfo(op_name, stream_id, TaskInfoType::PROFILER_TRACE, false),
log_id_(log_id),
notify_(notify),
flat_(flat) {}
~ProfilerTraceTaskInfo() override {}
uint64_t log_id() const { return log_id_; }
bool notify() const { return notify_; }
uint32_t flat() const { return flat_; }
private:
uint64_t log_id_;
bool notify_;
uint32_t flat_;
};
class MemcpyAsyncTaskInfo : public TaskInfo {
public:
MemcpyAsyncTaskInfo(const std::string &op_name, uint32_t stream_id, void *dst, uint64_t dst_max, void *src,
uint64_t count, uint32_t kind, bool dump_flag)
: TaskInfo(op_name, stream_id, TaskInfoType::MEMCPY_ASYNC, dump_flag),
dst_(dst),
dst_max_(dst_max),
src_(src),
count_(count),
kind_(kind) {}
~MemcpyAsyncTaskInfo() override {}
void *dst() const { return dst_; }
uint64_t dst_max() const { return dst_max_; }
void *src() const { return src_; }
uint64_t count() const { return count_; }
uint32_t kind() const { return kind_; }
private:
void *dst_;
uint64_t dst_max_;
void *src_;
uint64_t count_;
uint32_t kind_;
};
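// kind mirrors the copy direction and is presumably forwarded as the
// rtMemcpyKind_t argument of the underlying rtMemcpyAsync call.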
class StreamSwitchTaskInfo : public TaskInfo {
public:
StreamSwitchTaskInfo(const std::string &op_name, uint32_t stream_id, int64_t true_stream_id, void *input_addr,
void *value_addr, int64_t cond, int64_t data_type)
: TaskInfo(op_name, stream_id, TaskInfoType::STREAM_SWITCH, false),
true_stream_id_(true_stream_id),
input_addr_(input_addr),
value_addr_(value_addr),
cond_(cond),
data_type_(data_type) {}
~StreamSwitchTaskInfo() override {}
int64_t true_stream_id() const { return true_stream_id_; }
void *input_addr() const { return input_addr_; }
void *value_addr() const { return value_addr_; }
int64_t cond() const { return cond_; }
int64_t data_type() const { return data_type_; }
private:
int64_t true_stream_id_;
void *input_addr_;
void *value_addr_;
int64_t cond_;
int64_t data_type_;
};
class StreamActiveTaskInfo : public TaskInfo {
public:
StreamActiveTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t active_stream_id)
: TaskInfo(op_name, stream_id, TaskInfoType::STREAM_ACTIVE, false), active_stream_id_(active_stream_id) {}
~StreamActiveTaskInfo() override {}
uint32_t active_stream_id() const { return active_stream_id_; }
private:
uint32_t active_stream_id_;
};
} // namespace mindspore::ge::model_runner
#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_INFO_H_

View File

@ -25,7 +25,7 @@
#include "runtime/device/kernel_runtime.h"
#include "ir/anf.h"
#include "backend/kernel_compiler/ascend_kernel_mod.h"
#include "framework/ge_runtime/task_info.h"
#include "runtime/device/ascend/ge_runtime/task_info.h"
namespace mindspore {
namespace device {

View File

@ -134,6 +134,7 @@ enum SubModuleId : int {
SM_HCCL_ADPT, // Hccl Adapter
SM_MINDQUANTUM, // MindQuantum
SM_RUNTIME_FRAMEWORK, // Runtime framework
SM_GE, // GraphEngine
NUM_SUBMODUES // number of submodules
};
@ -142,34 +143,35 @@ enum SubModuleId : int {
#endif
static const char *SUB_MODULE_NAMES[NUM_SUBMODUES] = {
"UNKNOWN", // SM_UNKNOWN
"CORE", // SM_CORE
"ANALYZER", // SM_ANALYZER
"COMMON", // SM_COMMON
"DEBUG", // SM_DEBUG
"OFFLINE_DEBUG", // SM_OFFLINE_DEBUG
"DEVICE", // SM_DEVICE
"GE_ADPT", // SM_GE_ADPT
"IR", // SM_IR
"KERNEL", // SM_KERNEL
"MD", // SM_MD
"ME", // SM_ME
"EXPRESS", // SM_EXPRESS
"OPTIMIZER", // SM_OPTIMIZER
"PARALLEL", // SM_PARALLEL
"PARSER", // SM_PARSER
"PIPELINE", // SM_PIPELINE
"PRE_ACT", // SM_PRE_ACT
"PYNATIVE", // SM_PYNATIVE
"SESSION", // SM_SESSION
"UTILS", // SM_UTILS
"VM", // SM_VM
"PROFILER", // SM_PROFILER
"PS", // SM_PS
"LITE", // SM_LITE
"HCCL_ADPT", // SM_HCCL_ADPT
"MINDQUANTUM", // SM_MINDQUANTUM
"RUNTIME_FRAMEWORK" // SM_RUNTIME_FRAMEWORK
"UNKNOWN", // SM_UNKNOWN
"CORE", // SM_CORE
"ANALYZER", // SM_ANALYZER
"COMMON", // SM_COMMON
"DEBUG", // SM_DEBUG
"OFFLINE_DEBUG", // SM_OFFLINE_DEBUG
"DEVICE", // SM_DEVICE
"GE_ADPT", // SM_GE_ADPT
"IR", // SM_IR
"KERNEL", // SM_KERNEL
"MD", // SM_MD
"ME", // SM_ME
"EXPRESS", // SM_EXPRESS
"OPTIMIZER", // SM_OPTIMIZER
"PARALLEL", // SM_PARALLEL
"PARSER", // SM_PARSER
"PIPELINE", // SM_PIPELINE
"PRE_ACT", // SM_PRE_ACT
"PYNATIVE", // SM_PYNATIVE
"SESSION", // SM_SESSION
"UTILS", // SM_UTILS
"VM", // SM_VM
"PROFILER", // SM_PROFILER
"PS", // SM_PS
"LITE", // SM_LITE
"HCCL_ADPT", // SM_HCCL_ADPT
"MINDQUANTUM", // SM_MINDQUANTUM
"RUNTIME_FRAMEWORK", // SM_RUNTIME_FRAMEWORK
"GE", // SM_GE
};
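// Keep SUB_MODULE_NAMES index-aligned with enum SubModuleId: every new id
// (such as SM_GE here) needs a name entry at the same position.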
#if defined(_WIN32) || defined(_WIN64)

View File

@ -23,6 +23,7 @@ include_directories(${CMAKE_SOURCE_DIR}/mindspore/core)
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/stub/runtime/)
include_directories(${CMAKE_BINARY_DIR})
include_directories(${CMAKE_BINARY_DIR}/proto/ge)
include_directories(${CUDA_INCLUDE_DIRS})
MESSAGE("check ut_test ${CMAKE_BINARY_DIR}")
@ -104,6 +105,7 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"../../../mindspore/ccsrc/runtime/device/bucket.cc"
"../../../mindspore/ccsrc/runtime/device/launch_kernel.cc"
"../../../mindspore/ccsrc/runtime/device/ascend/profiling/*.cc"
"../../../mindspore/ccsrc/runtime/device/ascend/ge_runtime/*.cc"
"../../../mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc"
"../../../mindspore/ccsrc/runtime/device/ascend/ascend_launch_kernel.cc"
"../../../mindspore/ccsrc/runtime/device/ascend/ascend_launch_mul.cc"

View File

@ -0,0 +1,473 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <memory>
#include "common/common_test.h"
#define private public
#include "runtime/device/ascend/ge_runtime/model_runner.h"
#include "runtime/device/ascend/ge_runtime/runtime_model.h"
#include "runtime/device/ascend/ge_runtime/task/task_factory.h"
#include "runtime/device/ascend/ge_runtime/task/aicpu_task.h"
#include "runtime/device/ascend/ge_runtime/task/event_record_task.h"
#include "runtime/device/ascend/ge_runtime/task/event_wait_task.h"
#include "runtime/device/ascend/ge_runtime/task/hccl_task.h"
#include "runtime/device/ascend/ge_runtime/task/label_goto_task.h"
#include "runtime/device/ascend/ge_runtime/task/label_manager.h"
#include "runtime/device/ascend/ge_runtime/task/label_set_task.h"
#include "runtime/device/ascend/ge_runtime/task/label_switch_task.h"
#include "runtime/device/ascend/ge_runtime/task/memcpy_async_task.h"
#include "runtime/device/ascend/ge_runtime/task/profiler_task.h"
#include "runtime/device/ascend/ge_runtime/task/stream_active_task.h"
#include "runtime/device/ascend/ge_runtime/task/stream_switch_task.h"
#include "runtime/device/ascend/ge_runtime/task/tbe_task.h"
#undef private
#include "common/opskernel/ops_kernel_info_store.h"
using namespace mindspore::ge::model_runner;
using namespace testing;
class MockOpsKernelInfoStore : public ge::OpsKernelInfoStore {
public:
ge::Status Initialize(const std::map<std::string, std::string> &) override { return ge::SUCCESS; }
ge::Status Finalize() override { return ge::SUCCESS; }
void GetAllOpsKernelInfo(std::map<std::string, ge::OpInfo> &infos) const override {}
bool CheckSupported(const ge::OpDescPtr &opDescPtr, std::string &un_supported_reason) const override { return true; }
ge::Status LoadTask(ge::GETaskInfo &task) override { return ge::SUCCESS; }
};
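// A do-nothing kernel info store: HcclTask presumably casts its
// ops_kernel_store pointer back to ge::OpsKernelInfoStore and calls LoadTask()
// during Distribute(), so returning ge::SUCCESS is all these tests need.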
namespace mindspore {
class TestAscendGeRuntime : public UT::Common {
public:
TestAscendGeRuntime() {}
private:
void TearDown() override {
{
std::lock_guard<std::mutex> lock(HcclTask::model_stream_mapping_mutex_);
HcclTask::model_stream_mapping_.clear();
}
}
};
TEST_F(TestAscendGeRuntime, test_task_create_null_task_info_failed) {
ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
{reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(1)},
{reinterpret_cast<rtLabel_t>(1)}, {reinterpret_cast<rtEvent_t>(1)});
ASSERT_TRUE(TaskFactory::GetInstance().Create(model_context, nullptr) == nullptr);
}
TEST_F(TestAscendGeRuntime, test_aicpu_task_create_one_stream_success) {
ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
{reinterpret_cast<rtStream_t>(1)}, {reinterpret_cast<rtLabel_t>(1)},
{reinterpret_cast<rtEvent_t>(1)});
std::shared_ptr<TaskInfo> aicpu_task_info = std::make_shared<AicpuTaskInfo>(
"op_name", 0, "so_name", "kernel_name", "node_def", "ext_info", std::vector<void *>{reinterpret_cast<void *>(1)},
std::vector<void *>{reinterpret_cast<void *>(1)}, true);
std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, aicpu_task_info);
ASSERT_TRUE(std::dynamic_pointer_cast<AicpuTask>(task) != nullptr);
ASSERT_NO_THROW(task->Distribute());
}
TEST_F(TestAscendGeRuntime, test_aicpu_task_create_multi_stream_success) {
ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
{reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(1)},
{reinterpret_cast<rtLabel_t>(1)}, {reinterpret_cast<rtEvent_t>(1)});
std::shared_ptr<TaskInfo> aicpu_task_info = std::make_shared<AicpuTaskInfo>(
"op_name", 0, "so_name", "kernel_name", "node_def", "", std::vector<void *>{reinterpret_cast<void *>(1)},
std::vector<void *>{reinterpret_cast<void *>(1)}, true);
std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, aicpu_task_info);
ASSERT_TRUE(std::dynamic_pointer_cast<AicpuTask>(task) != nullptr);
ASSERT_NO_THROW(task->Distribute());
}
TEST_F(TestAscendGeRuntime, test_aicpu_task_create_invalid_stream_id_failed) {
ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
{reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(1)},
{reinterpret_cast<rtLabel_t>(1)}, {reinterpret_cast<rtEvent_t>(1)});
std::shared_ptr<TaskInfo> aicpu_task_info = std::make_shared<AicpuTaskInfo>(
"op_name", 5, "so_name", "kernel_name", "node_def", "", std::vector<void *>{reinterpret_cast<void *>(1)},
std::vector<void *>{reinterpret_cast<void *>(1)}, true);
ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(model_context, aicpu_task_info));
}
TEST_F(TestAscendGeRuntime, test_event_record_task_create_success) {
ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
{reinterpret_cast<rtStream_t>(1)}, {reinterpret_cast<rtLabel_t>(1)},
{reinterpret_cast<rtEvent_t>(1)});
std::shared_ptr<TaskInfo> event_record_task_info = std::make_shared<EventRecordTaskInfo>("op_name", 0, 0);
std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, event_record_task_info);
ASSERT_TRUE(std::dynamic_pointer_cast<EventRecordTask>(task) != nullptr);
ASSERT_NO_THROW(task->Distribute());
}
TEST_F(TestAscendGeRuntime, test_event_record_task_create_invalid_event_id_failed) {
ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
{reinterpret_cast<rtStream_t>(1)}, {reinterpret_cast<rtLabel_t>(1)},
{reinterpret_cast<rtEvent_t>(1)});
std::shared_ptr<TaskInfo> event_record_task_info = std::make_shared<EventRecordTaskInfo>("op_name", 0, 10);
ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(model_context, event_record_task_info));
}
TEST_F(TestAscendGeRuntime, test_event_wait_task_create_success) {
ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
{reinterpret_cast<rtStream_t>(1)}, {reinterpret_cast<rtLabel_t>(1)},
{reinterpret_cast<rtEvent_t>(1)});
std::shared_ptr<TaskInfo> event_record_task_info = std::make_shared<EventWaitTaskInfo>("op_name", 0, 0);
std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, event_record_task_info);
ASSERT_TRUE(std::dynamic_pointer_cast<EventWaitTask>(task) != nullptr);
ASSERT_NO_THROW(task->Distribute());
}
TEST_F(TestAscendGeRuntime, test_event_wait_task_create_invalid_event_id_failed) {
ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
{reinterpret_cast<rtStream_t>(1)}, {reinterpret_cast<rtLabel_t>(1)},
{reinterpret_cast<rtEvent_t>(1)});
std::shared_ptr<TaskInfo> event_record_task_info = std::make_shared<EventWaitTaskInfo>("op_name", 0, 10);
ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(model_context, event_record_task_info));
}
TEST_F(TestAscendGeRuntime, test_hccl_task_create_success) {
MockOpsKernelInfoStore ops_kernel_info_store;
ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
{reinterpret_cast<rtStream_t>(1)}, {reinterpret_cast<rtLabel_t>(1)},
{reinterpret_cast<rtEvent_t>(1)});
std::shared_ptr<TaskInfo> hccl_task_info = std::make_shared<HcclTaskInfo>(
"op_name", 0, "hccl_type", reinterpret_cast<void *>(1), reinterpret_cast<void *>(2), reinterpret_cast<void *>(3), 4,
5, std::vector<uint8_t>(6, 7), reinterpret_cast<void *>(&ops_kernel_info_store), 9, 10, 11, 12, "group", true);
std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, hccl_task_info);
ASSERT_TRUE(std::dynamic_pointer_cast<HcclTask>(task) != nullptr);
ASSERT_NO_THROW(task->Distribute());
}
TEST_F(TestAscendGeRuntime, test_hccl_task_create_stream_reuse_success) {
const rtModel_t model = reinterpret_cast<rtModel_t>(0x12345678);
const rtStream_t stream = reinterpret_cast<rtStream_t>(0x87654321);
constexpr uint32_t stream_id = 0;
constexpr int64_t task1_stream_num = 3;
constexpr int64_t task2_stream_num = 5;
constexpr int64_t task3_stream_num = 4;
MockOpsKernelInfoStore ops_kernel_info_store;
ModelContext model_context(0, 0, 0, model, reinterpret_cast<rtStream_t>(2), {stream},
{reinterpret_cast<rtLabel_t>(1)}, {reinterpret_cast<rtEvent_t>(1)});
std::shared_ptr<TaskInfo> hccl_task_info_1 = std::make_shared<HcclTaskInfo>(
"op_name", stream_id, "hccl_type", reinterpret_cast<void *>(1), reinterpret_cast<void *>(2),
reinterpret_cast<void *>(3), 4, task1_stream_num, std::vector<uint8_t>(6, 7),
reinterpret_cast<void *>(&ops_kernel_info_store), 9, 10, 11, 12, "group", true);
std::shared_ptr<TaskInfo> hccl_task_info_2 = std::make_shared<HcclTaskInfo>(
"op_name", stream_id, "hccl_type", reinterpret_cast<void *>(1), reinterpret_cast<void *>(2),
reinterpret_cast<void *>(3), 4, task2_stream_num, std::vector<uint8_t>(6, 7),
reinterpret_cast<void *>(&ops_kernel_info_store), 9, 10, 11, 12, "group", true);
std::shared_ptr<TaskInfo> hccl_task_info_3 = std::make_shared<HcclTaskInfo>(
"op_name", stream_id, "hccl_type", reinterpret_cast<void *>(1), reinterpret_cast<void *>(2),
reinterpret_cast<void *>(3), 4, task3_stream_num, std::vector<uint8_t>(6, 7),
reinterpret_cast<void *>(&ops_kernel_info_store), 9, 10, 11, 12, "group", true);
std::shared_ptr<Task> task_1 = TaskFactory::GetInstance().Create(model_context, hccl_task_info_1);
std::shared_ptr<Task> task_2 = TaskFactory::GetInstance().Create(model_context, hccl_task_info_2);
std::shared_ptr<Task> task_3 = TaskFactory::GetInstance().Create(model_context, hccl_task_info_3);
ASSERT_TRUE(std::dynamic_pointer_cast<HcclTask>(task_1) != nullptr);
ASSERT_TRUE(std::dynamic_pointer_cast<HcclTask>(task_2) != nullptr);
ASSERT_TRUE(std::dynamic_pointer_cast<HcclTask>(task_3) != nullptr);
ASSERT_NO_THROW(task_1->Distribute());
ASSERT_NO_THROW(task_2->Distribute());
ASSERT_NO_THROW(task_3->Distribute());
{
std::lock_guard<std::mutex> lock(HcclTask::model_stream_mapping_mutex_);
auto model_iter = HcclTask::model_stream_mapping_.find(model);
ASSERT_NE(model_iter, HcclTask::model_stream_mapping_.end());
auto stream_iter = model_iter->second.find(stream_id);
ASSERT_NE(stream_iter, model_iter->second.end());
const auto &stream_vec = stream_iter->second;
ASSERT_EQ(stream_vec.size(), std::max(task1_stream_num, std::max(task2_stream_num, task3_stream_num)));
for (const auto &s : stream_vec) {
auto shared = s.lock();
ASSERT_TRUE(shared != nullptr);
}
}
}
TEST_F(TestAscendGeRuntime, test_label_goto_task_create_success) {
ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
{reinterpret_cast<rtStream_t>(1)}, {reinterpret_cast<rtLabel_t>(1)},
{reinterpret_cast<rtEvent_t>(1)});
std::shared_ptr<TaskInfo> label_goto_task_info = std::make_shared<LabelGotoTaskInfo>("op_name", 0, 0);
std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, label_goto_task_info);
auto label_goto_task = std::dynamic_pointer_cast<LabelGotoTask>(task);
ASSERT_TRUE(label_goto_task != nullptr);
ASSERT_NO_THROW(task->Distribute());
label_goto_task->index_value_ = new uint8_t[5];  // hand the task a non-null buffer so its destructor's cleanup branch is covered
}
TEST_F(TestAscendGeRuntime, test_label_goto_task_create_invalid_label_id_failed) {
ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
{reinterpret_cast<rtStream_t>(1)}, {reinterpret_cast<rtLabel_t>(1)},
{reinterpret_cast<rtEvent_t>(1)});
std::shared_ptr<TaskInfo> label_goto_task_info = std::make_shared<LabelGotoTaskInfo>("op_name", 0, 1);
ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(model_context, label_goto_task_info));
}
TEST_F(TestAscendGeRuntime, test_label_goto_task_reuse_success) {
ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
{reinterpret_cast<rtStream_t>(1)}, {reinterpret_cast<rtLabel_t>(1)},
{reinterpret_cast<rtEvent_t>(1)});
std::shared_ptr<TaskInfo> label_goto_task_info = std::make_shared<LabelGotoTaskInfo>("op_name", 0, 0);
std::shared_ptr<Task> task1 = TaskFactory::GetInstance().Create(model_context, label_goto_task_info);
std::shared_ptr<Task> task2 = TaskFactory::GetInstance().Create(model_context, label_goto_task_info);
auto label_goto_task_1 = std::dynamic_pointer_cast<LabelGotoTask>(task1);
auto label_goto_task_2 = std::dynamic_pointer_cast<LabelGotoTask>(task2);
ASSERT_TRUE(label_goto_task_1 != nullptr);
ASSERT_NO_THROW(task1->Distribute());
ASSERT_TRUE(label_goto_task_2 != nullptr);
ASSERT_NO_THROW(task2->Distribute());
ASSERT_EQ(label_goto_task_1->label_info_, label_goto_task_2->label_info_);
}
TEST_F(TestAscendGeRuntime, test_label_set_task_create_success) {
ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
{reinterpret_cast<rtStream_t>(1)}, {reinterpret_cast<rtLabel_t>(1)},
{reinterpret_cast<rtEvent_t>(1)});
std::shared_ptr<TaskInfo> label_set_task_info = std::make_shared<LabelSetTaskInfo>("op_name", 0, 0);
std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, label_set_task_info);
ASSERT_TRUE(std::dynamic_pointer_cast<LabelSetTask>(task) != nullptr);
ASSERT_NO_THROW(task->Distribute());
}
TEST_F(TestAscendGeRuntime, test_label_set_task_create_invalid_label_id_failed) {
ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
{reinterpret_cast<rtStream_t>(1)}, {reinterpret_cast<rtLabel_t>(1)},
{reinterpret_cast<rtEvent_t>(1)});
std::shared_ptr<TaskInfo> label_set_task_info = std::make_shared<LabelGotoTaskInfo>("op_name", 0, 1);
ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(model_context, label_set_task_info));
}
TEST_F(TestAscendGeRuntime, test_label_switch_task_create_success) {
ModelContext model_context(
0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), {reinterpret_cast<rtStream_t>(1)},
{reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)}, {reinterpret_cast<rtEvent_t>(1)});
std::shared_ptr<TaskInfo> label_switch_task_info =
std::make_shared<LabelSwitchTaskInfo>("op_name", 0, 2, std::vector<uint32_t>{0, 1}, reinterpret_cast<void *>(1));
std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, label_switch_task_info);
ASSERT_TRUE(std::dynamic_pointer_cast<LabelSwitchTask>(task) != nullptr);
ASSERT_NO_THROW(task->Distribute());
}
TEST_F(TestAscendGeRuntime, test_label_switch_task_create_invalid_stream_id_failed) {
ModelContext model_context(
0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), {reinterpret_cast<rtStream_t>(1)},
{reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)}, {reinterpret_cast<rtEvent_t>(1)});
std::shared_ptr<TaskInfo> label_switch_task_info =
std::make_shared<LabelSwitchTaskInfo>("op_name", 1, 2, std::vector<uint32_t>{0, 1}, reinterpret_cast<void *>(1));
ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(model_context, label_switch_task_info));
}
TEST_F(TestAscendGeRuntime, test_label_switch_task_create_invalid_label_id_failed) {
ModelContext model_context(
0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), {reinterpret_cast<rtStream_t>(1)},
{reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)}, {reinterpret_cast<rtEvent_t>(1)});
std::shared_ptr<TaskInfo> label_switch_task_info =
std::make_shared<LabelSwitchTaskInfo>("op_name", 0, 3, std::vector<uint32_t>{0, 1, 2}, reinterpret_cast<void *>(1));
ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(model_context, label_switch_task_info));
}
TEST_F(TestAscendGeRuntime, test_label_switch_task_reuse_success) {
ModelContext model_context(
0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), {reinterpret_cast<rtStream_t>(1)},
{reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)}, {reinterpret_cast<rtEvent_t>(1)});
std::shared_ptr<TaskInfo> label_switch_task_info =
std::make_shared<LabelSwitchTaskInfo>("op_name", 0, 2, std::vector<uint32_t>{0, 1}, reinterpret_cast<void *>(1));
std::shared_ptr<Task> task1 = TaskFactory::GetInstance().Create(model_context, label_switch_task_info);
std::shared_ptr<Task> task2 = TaskFactory::GetInstance().Create(model_context, label_switch_task_info);
auto label_switch_task_1 = std::dynamic_pointer_cast<LabelSwitchTask>(task1);
auto label_switch_task_2 = std::dynamic_pointer_cast<LabelSwitchTask>(task2);
ASSERT_TRUE(label_switch_task_1 != nullptr);
ASSERT_TRUE(label_switch_task_2 != nullptr);
ASSERT_NO_THROW(task1->Distribute());
ASSERT_NO_THROW(task2->Distribute());
ASSERT_EQ(label_switch_task_1->label_info_, label_switch_task_2->label_info_);
}
TEST_F(TestAscendGeRuntime, test_memcpy_async_task_create_success) {
ModelContext model_context(
0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), {reinterpret_cast<rtStream_t>(1)},
{reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)}, {reinterpret_cast<rtEvent_t>(1)});
std::shared_ptr<TaskInfo> memcpy_task_info = std::make_shared<MemcpyAsyncTaskInfo>(
"op_name", 0, reinterpret_cast<void *>(1), 2, reinterpret_cast<void *>(3), 4, 5, true);
std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, memcpy_task_info);
ASSERT_TRUE(std::dynamic_pointer_cast<MemcpyAsyncTask>(task) != nullptr);
ASSERT_NO_THROW(task->Distribute());
}
TEST_F(TestAscendGeRuntime, test_memcpy_async_task_create_invalid_stream_id_failed) {
ModelContext model_context(
0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), {reinterpret_cast<rtStream_t>(1)},
{reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)}, {reinterpret_cast<rtEvent_t>(1)});
std::shared_ptr<TaskInfo> memcpy_task_info = std::make_shared<MemcpyAsyncTaskInfo>(
"op_name", 1, reinterpret_cast<void *>(1), 2, reinterpret_cast<void *>(3), 4, 5, true);
ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(model_context, memcpy_task_info));
}
TEST_F(TestAscendGeRuntime, test_profiler_task_create_success) {
ModelContext model_context(
0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), {reinterpret_cast<rtStream_t>(1)},
{reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)}, {reinterpret_cast<rtEvent_t>(1)});
std::shared_ptr<TaskInfo> profiler_task_info = std::make_shared<ProfilerTraceTaskInfo>("op_name", 0, 1, true, 2);
std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, profiler_task_info);
ASSERT_TRUE(std::dynamic_pointer_cast<ProfilerTask>(task) != nullptr);
ASSERT_NO_THROW(task->Distribute());
}
TEST_F(TestAscendGeRuntime, test_profiler_task_create_invalid_stream_id_failed) {
ModelContext model_context(
0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), {reinterpret_cast<rtStream_t>(1)},
{reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)}, {reinterpret_cast<rtEvent_t>(1)});
std::shared_ptr<TaskInfo> profiler_task_info = std::make_shared<ProfilerTraceTaskInfo>("op_name", 1, 1, true, 2);
ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(model_context, profiler_task_info));
}
TEST_F(TestAscendGeRuntime, test_stream_active_task_create_success) {
ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
{reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(2)},
{reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)},
{reinterpret_cast<rtEvent_t>(1)});
std::shared_ptr<TaskInfo> stream_active_task_info = std::make_shared<StreamActiveTaskInfo>("op_name", 0, 1);
std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, stream_active_task_info);
ASSERT_TRUE(std::dynamic_pointer_cast<StreamActiveTask>(task) != nullptr);
ASSERT_NO_THROW(task->Distribute());
}
TEST_F(TestAscendGeRuntime, test_stream_active_task_create_invalid_active_stream_id_failed) {
ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
{reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(2)},
{reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)},
{reinterpret_cast<rtEvent_t>(1)});
std::shared_ptr<TaskInfo> stream_active_task_info = std::make_shared<StreamActiveTaskInfo>("op_name", 0, 2);
ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(model_context, stream_active_task_info));
}
TEST_F(TestAscendGeRuntime, test_stream_switch_task_create_success) {
ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
{reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(2)},
{reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)},
{reinterpret_cast<rtEvent_t>(1)});
std::shared_ptr<TaskInfo> stream_switch_task_info = std::make_shared<StreamSwitchTaskInfo>(
"op_name", 0, 1, reinterpret_cast<void *>(2), reinterpret_cast<void *>(3), 4, 5);
std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, stream_switch_task_info);
ASSERT_TRUE(std::dynamic_pointer_cast<StreamSwitchTask>(task) != nullptr);
ASSERT_NO_THROW(task->Distribute());
}
TEST_F(TestAscendGeRuntime, test_stream_switch_task_create_invalid_true_stream_id_failed) {
ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
{reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(2)},
{reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)},
{reinterpret_cast<rtEvent_t>(1)});
std::shared_ptr<TaskInfo> stream_switch_task_info = std::make_shared<StreamSwitchTaskInfo>(
"op_name", 0, 2, reinterpret_cast<void *>(2), reinterpret_cast<void *>(3), 4, 5);
std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, stream_switch_task_info);
ASSERT_TRUE(std::dynamic_pointer_cast<StreamSwitchTask>(task) != nullptr);
ASSERT_ANY_THROW(task->Distribute());
}
TEST_F(TestAscendGeRuntime, test_stream_switch_task_create_invalid_stream_id_failed) {
ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
{reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(2)},
{reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)},
{reinterpret_cast<rtEvent_t>(1)});
std::shared_ptr<TaskInfo> stream_switch_task_info = std::make_shared<StreamSwitchTaskInfo>(
"op_name", 2, 1, reinterpret_cast<void *>(2), reinterpret_cast<void *>(3), 4, 5);
ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(model_context, stream_switch_task_info));
}
TEST_F(TestAscendGeRuntime, test_tbe_task_create_success) {
ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
{reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(2)},
{reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)},
{reinterpret_cast<rtEvent_t>(1)});
std::shared_ptr<TaskInfo> tbe_task_info = std::make_shared<TbeTaskInfo>(
"op_name", 0, "stub_func", 1, std::vector<uint8_t>(100, 2), 100, std::vector<uint8_t>{5, 6},
reinterpret_cast<void *>(7), 8, std::vector<uint8_t>{9, 10},
std::vector<void *>{reinterpret_cast<void *>(11), reinterpret_cast<void *>(12)},
std::vector<void *>{reinterpret_cast<void *>(13), reinterpret_cast<void *>(14)},
std::vector<void *>{reinterpret_cast<void *>(15), reinterpret_cast<void *>(16)}, true);
std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, tbe_task_info);
auto tbe_task = std::dynamic_pointer_cast<TbeTask>(task);
ASSERT_TRUE(tbe_task != nullptr);
ASSERT_NO_THROW(task->Distribute());
tbe_task->args_ = new uint8_t[5];  // hand the task a non-null args buffer so its destructor's cleanup branch is covered
}
TEST_F(TestAscendGeRuntime, test_tbe_task_create_invalid_stream_id_failed) {
ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
{reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(2)},
{reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)},
{reinterpret_cast<rtEvent_t>(1)});
std::shared_ptr<TaskInfo> tbe_task_info = std::make_shared<TbeTaskInfo>(
"op_name", 3, "stub_func", 1, std::vector<uint8_t>(100, 2), 100, std::vector<uint8_t>{5, 6},
reinterpret_cast<void *>(7), 8, std::vector<uint8_t>{9, 10},
std::vector<void *>{reinterpret_cast<void *>(11), reinterpret_cast<void *>(12)},
std::vector<void *>{reinterpret_cast<void *>(13), reinterpret_cast<void *>(14)},
std::vector<void *>{reinterpret_cast<void *>(15), reinterpret_cast<void *>(16)}, true);
ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(model_context, tbe_task_info));
}
TEST_F(TestAscendGeRuntime, test_tbe_task_create_empty_stub_func_failed) {
ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
{reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(2)},
{reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)},
{reinterpret_cast<rtEvent_t>(1)});
std::shared_ptr<TaskInfo> tbe_task_info = std::make_shared<TbeTaskInfo>(
"op_name", 0, "", 1, std::vector<uint8_t>(100, 2), 100, std::vector<uint8_t>{5, 6}, reinterpret_cast<void *>(7), 8,
std::vector<uint8_t>{9, 10}, std::vector<void *>{reinterpret_cast<void *>(11), reinterpret_cast<void *>(12)},
std::vector<void *>{reinterpret_cast<void *>(13), reinterpret_cast<void *>(14)},
std::vector<void *>{reinterpret_cast<void *>(15), reinterpret_cast<void *>(16)}, true);
std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, tbe_task_info);
ASSERT_TRUE(std::dynamic_pointer_cast<TbeTask>(task) != nullptr);
ASSERT_ANY_THROW(task->Distribute());
}
TEST_F(TestAscendGeRuntime, test_model_runner_success) {
constexpr uint32_t model_id = 0;
ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2),
{reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(2)},
{reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)},
{reinterpret_cast<rtEvent_t>(1)});
std::shared_ptr<TaskInfo> tbe_task_info = std::make_shared<TbeTaskInfo>(
"op_name", 0, "stub_func", 1, std::vector<uint8_t>(100, 2), 100, std::vector<uint8_t>{5, 6},
reinterpret_cast<void *>(7), 8, std::vector<uint8_t>{9, 10},
std::vector<void *>{reinterpret_cast<void *>(11), reinterpret_cast<void *>(12)},
std::vector<void *>{reinterpret_cast<void *>(13), reinterpret_cast<void *>(14)},
std::vector<void *>{reinterpret_cast<void *>(15), reinterpret_cast<void *>(16)}, true);
std::shared_ptr<TaskInfo> aicpu_task_info = std::make_shared<AicpuTaskInfo>(
"op_name", 0, "so_name", "kernel_name", "node_def", "ext_info", std::vector<void *>{reinterpret_cast<void *>(1)},
std::vector<void *>{reinterpret_cast<void *>(1)}, true);
auto davinci_model =
std::make_shared<DavinciModel>(std::vector<std::shared_ptr<TaskInfo>>{tbe_task_info, aicpu_task_info},
std::vector<uint32_t>{}, std::vector<uint32_t>{}, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0);
ASSERT_NO_THROW(ModelRunner::Instance().LoadDavinciModel(0, 0, model_id, davinci_model));
auto iter = ModelRunner::Instance().runtime_models_.find(model_id);
ASSERT_TRUE(iter != ModelRunner::Instance().runtime_models_.end());
auto &task_list = iter->second->task_list_;
task_list.clear();
ASSERT_NO_THROW(task_list.emplace_back(TaskFactory::GetInstance().Create(model_context, tbe_task_info)));
ASSERT_NO_THROW(task_list.emplace_back(TaskFactory::GetInstance().Create(model_context, aicpu_task_info)));
ASSERT_NO_THROW(ModelRunner::Instance().DistributeTask(model_id));
ASSERT_NO_THROW(ModelRunner::Instance().LoadModelComplete(model_id));
ASSERT_NO_THROW(ModelRunner::Instance().RunModel(model_id));
ASSERT_FALSE(ModelRunner::Instance().GetTaskIdList(model_id).empty());
ASSERT_FALSE(ModelRunner::Instance().GetStreamIdList(model_id).empty());
ASSERT_FALSE(ModelRunner::Instance().GetRuntimeInfoMap(model_id).empty());
ASSERT_NO_THROW(ModelRunner::Instance().GetModelHandle(model_id));
ASSERT_NO_THROW(ModelRunner::Instance().UnloadModel(model_id));
}
} // namespace mindspore
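
Taken together, test_model_runner_success also documents the expected ModelRunner lifecycle; in sketch form (names as in the test, error handling omitted):

ModelRunner::Instance().LoadDavinciModel(device_id, session_id, model_id, davinci_model);
ModelRunner::Instance().DistributeTask(model_id);      // sink each queued task to the rt runtime
ModelRunner::Instance().LoadModelComplete(model_id);
ModelRunner::Instance().RunModel(model_id);
// task ids, stream ids and runtime info are queryable once tasks are distributed
ModelRunner::Instance().UnloadModel(model_id);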

View File

@ -14,51 +14,8 @@
* limitations under the License.
*/
#include <vector>
#include "framework/ge_runtime/model_runner.h"
#include "runtime/hccl_adapter/hccl_adapter.h"
namespace ge {
namespace model_runner {
ModelRunner &ModelRunner::Instance() {
static ModelRunner runner;
return runner;
}
bool ModelRunner::LoadDavinciModel(uint32_t device_id, uint64_t session_id, uint32_t model_id,
std::shared_ptr<DavinciModel> ascend_model,
std::shared_ptr<ge::ModelListener> listener) {
return true;
}
bool ModelRunner::UnloadModel(uint32_t model_id) { return true; }
bool ModelRunner::LoadModelComplete(uint32_t model_id) { return true; }
bool ModelRunner::RunModel(uint32_t model_id, const ge::InputData &input_data, ge::OutputData *output_data) {
return true;
}
void *ModelRunner::GetModelHandle(uint32_t model_id) const { return nullptr; }
bool ModelRunner::DistributeTask(uint32_t model_id) { return true; }
const std::vector<uint32_t> &ModelRunner::GetTaskIdList(uint32_t model_id) const {
static std::vector<uint32_t> task_id_list;
return task_id_list;
}
const std::vector<uint32_t> &ModelRunner::GetStreamIdList(uint32_t model_id) const {
static std::vector<uint32_t> stream_id_list;
return stream_id_list;
}
const std::map<std::string, std::shared_ptr<RuntimeInfo>> &ModelRunner::GetRuntimeInfoMap(uint32_t model_id) const {
static std::map<std::string, std::shared_ptr<RuntimeInfo>> runtime_info_map;
return runtime_info_map;
}
} // namespace model_runner
} // namespace ge
namespace mindspore {
namespace hccl {
bool InitHccl(uint32_t, std::string_view, std::string_view) { return true; }

View File

@ -141,9 +141,9 @@ rtError_t rtGetFunctionByName(const char *stubName, void **stubFunc) { return RT
rtError_t rtSetTaskGenCallback(rtTaskGenCallback callback) { return RT_ERROR_NONE; }
RTS_API rtError_t rtProfilerStart(uint64_t profConfig, int32_t numsDev, uint32_t* deviceList) { return RT_ERROR_NONE; }
RTS_API rtError_t rtProfilerStart(uint64_t profConfig, int32_t numsDev, uint32_t *deviceList) { return RT_ERROR_NONE; }
RTS_API rtError_t rtProfilerStop(uint64_t profConfig, int32_t numsDev, uint32_t* deviceList) { return RT_ERROR_NONE; }
RTS_API rtError_t rtProfilerStop(uint64_t profConfig, int32_t numsDev, uint32_t *deviceList) { return RT_ERROR_NONE; }
int AdxDataDumpServerInit() { return 0; }
@ -151,11 +151,13 @@ int AdxDataDumpServerUnInit() { return 0; }
RTS_API rtError_t rtGetTaskIdAndStreamID(uint32_t *taskid, uint32_t *streamid) { return RT_ERROR_NONE; }
RTS_API rtError_t rtSetTaskFailCallback(rtTaskFailCallback callback) {return RT_ERROR_NONE; }
RTS_API rtError_t rtSetTaskFailCallback(rtTaskFailCallback callback) { return RT_ERROR_NONE; }
RTS_API rtError_t rtRegDeviceStateCallback(const char *regName, rtDeviceStateCallback callback) {return RT_ERROR_NONE; }
RTS_API rtError_t rtRegDeviceStateCallback(const char *regName, rtDeviceStateCallback callback) {
return RT_ERROR_NONE;
}
RTS_API rtError_t rtSetMsprofReporterCallback(MsprofReporterCallback callback) {return RT_ERROR_NONE; }
RTS_API rtError_t rtSetMsprofReporterCallback(MsprofReporterCallback callback) { return RT_ERROR_NONE; }
RTS_API rtError_t rtRegTaskFailCallbackByModule(const char *moduleName, rtTaskFailCallback callback) {
return RT_ERROR_NONE;
@ -168,3 +170,28 @@ RTS_API rtError_t rtDevBinaryUnRegister(void *handle) { return RT_ERROR_NONE; }
RTS_API rtError_t rtMemsetAsync(void *ptr, uint64_t destMax, uint32_t value, uint64_t count, rtStream_t stream) {
return RT_ERROR_NONE;
}
RTS_API rtError_t rtLabelListCpy(rtLabel_t *label, uint32_t labelNumber, void *dst, uint32_t dstMax) {
return RT_ERROR_NONE;
}
RTS_API rtError_t rtModelGetTaskId(rtModel_t model, uint32_t *taskid, uint32_t *streamid) { return RT_ERROR_NONE; }
RTS_API rtError_t rtLabelCreateEx(rtLabel_t *label, rtStream_t stream) { return RT_ERROR_NONE; }
RTS_API rtError_t rtCpuKernelLaunchWithFlag(const void *soName, const void *kernelName, uint32_t blockDim,
const void *args, uint32_t argsSize, rtSmDesc_t *smDesc, rtStream_t stream,
uint32_t flags) {
return RT_ERROR_NONE;
}
RTS_API rtError_t rtLabelSwitchByIndex(void *ptr, uint32_t max, void *labelInfoPtr, rtStream_t stream) {
return RT_ERROR_NONE;
}
RTS_API rtError_t rtProfilerTrace(uint64_t id, bool notify, uint32_t flags, rtStream_t stream) { return RT_ERROR_NONE; }
RTS_API rtError_t rtKernelLaunchWithFlag(const void *stubFunc, uint32_t blockDim, void *args, uint32_t argsSize,
rtSmDesc_t *smDesc, rtStream_t stream, uint32_t flags) {
return RT_ERROR_NONE;
}