New Device Context
parent a300990322
commit c1ad42df74

@@ -17,6 +17,7 @@
"mindspore/mindspore/ccsrc/plugin/device/gpu/hal/device/gpu_kernel_runtime.cc" "build/include_what_you_use"
|
||||
"mindspore/mindspore/ccsrc/utils/convert_utils_py.cc" "whitespace/indent"
|
||||
"mindspore/mindspore/core/utils/log_adapter.cc" "runtime/references"
|
||||
"mindspore/mindspore/ccsrc/runtime/hardware/device_context.h" "readability/braces"
|
||||
|
||||
# Modelzoo
|
||||
"mindspore/model_zoo/official/cv/yolov4_tiny/infer/mxbase/src/Yolov4TinyDetection.h" "runtime/references"
|
||||
|
|
|
@@ -2837,7 +2837,13 @@ void SessionBasic::InitAllBucket(const KernelGraphPtr &graph, const device::Devi
MS_LOG(INFO) << "Create new bucket:" << bucket_id << " size:" << bucket_size;
std::shared_ptr<device::Bucket> bucket = nullptr;
if (device_context != nullptr) {
bucket = device_context->CreateBucket(bucket_id++, bucket_size);
auto deprecated_kernel_executor =
dynamic_cast<device::DeprecatedKernelExecutor *>(device_context->kernel_executor_.get());
if (deprecated_kernel_executor != nullptr) {
bucket = deprecated_kernel_executor->CreateBucket(bucket_id++, bucket_size);
} else {
MS_LOG(EXCEPTION) << "Not Support CreateBucket() in Device Context.";
}
} else {
bucket = CreateBucket(bucket_id++, bucket_size);
}
@@ -2871,8 +2877,14 @@ void SessionBasic::DoAllReduceOnGrads(const std::string &actor_info, const std::
std::shared_ptr<device::Bucket> bucket;
auto iter = actor_set_to_bucket_.find(actor_info);
if (iter == actor_set_to_bucket_.end()) {
static size_t bucket_id = 0;
bucket = device_context->CreateBucket(bucket_id++, outputs.size());
auto deprecated_kernel_executor =
dynamic_cast<device::DeprecatedKernelExecutor *>(device_context->kernel_executor_.get());
if (deprecated_kernel_executor != nullptr) {
static size_t bucket_id = 0;
bucket = deprecated_kernel_executor->CreateBucket(bucket_id++, outputs.size());
} else {
MS_LOG(EXCEPTION) << "Not Support CreateBucket() in Device Context.";
}
actor_set_to_bucket_[actor_info] = bucket;
} else {
bucket = iter->second;
@@ -2987,7 +2999,11 @@ void SessionBasic::DumpGraphs(const std::vector<KernelGraphPtr> &graphs) {
uint32_t device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
const auto &device_context =
device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_target, device_id});
rank_id = device_context->GetRankID();
auto deprecated_kernel_executor =
dynamic_cast<device::DeprecatedKernelExecutor *>(device_context->kernel_executor_.get());
if (deprecated_kernel_executor != nullptr) {
rank_id = deprecated_kernel_executor->GetRankID();
}
}
std::string final_graph = "trace_code_graph_" + std::to_string(graph->graph_id());
if (json_parser.e2e_dump_enabled() || json_parser.async_dump_enabled()) {
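All three session_basic.cc hunks above apply the same call-site migration: helpers that used to be called directly on DeviceContext (CreateBucket, GetRankID) are now reached through the kernel_executor_ member, with a dynamic_cast to device::DeprecatedKernelExecutor as the bridge. A minimal sketch of that pattern; the helper name CreateBucketThroughExecutor is invented for illustration and is not part of this commit:

std::shared_ptr<device::Bucket> CreateBucketThroughExecutor(const device::DeviceContext *device_context,
                                                            uint32_t bucket_id, uint32_t bucket_size) {
  // The reworked DeviceContext no longer exposes CreateBucket() itself; only the
  // deprecated kernel-executor interface still implements it.
  auto *deprecated_kernel_executor =
    dynamic_cast<device::DeprecatedKernelExecutor *>(device_context->kernel_executor_.get());
  if (deprecated_kernel_executor == nullptr) {
    MS_LOG(EXCEPTION) << "Not Support CreateBucket() in Device Context.";
  }
  return deprecated_kernel_executor->CreateBucket(bucket_id, bucket_size);
}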
@@ -338,9 +338,9 @@ device::DeviceAddressPtr CloneEmptyDeviceAddress(const device::DeviceAddressPtr
const DeviceContext *device_context) {
MS_EXCEPTION_IF_NULL(old_device_address);
MS_EXCEPTION_IF_NULL(device_context);
auto new_device_address =
device_context->CreateDeviceAddress(nullptr, old_device_address->GetSize(), old_device_address->format(),
old_device_address->type_id(), old_device_address->host_shape());
auto new_device_address = device_context->device_res_manager_->CreateDeviceAddress(
nullptr, old_device_address->GetSize(), old_device_address->format(), old_device_address->type_id(),
old_device_address->host_shape());
MS_EXCEPTION_IF_NULL(new_device_address);
new_device_address->set_original_ref_count(old_device_address->original_ref_count());
new_device_address->ResetRefCount();
@@ -1300,7 +1300,7 @@ void MindRTBackend::SyncStream() {
const auto &device_context =
device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_name_, device_id_});
MS_EXCEPTION_IF_NULL(device_context);
(void)device_context->SyncStream();
(void)device_context->device_res_manager_->SyncStream();
}

std::unique_ptr<GraphCompilerInfo> MindRTBackend::ConstructGraphCompilerInfo(const FuncGraphPtr &root_graph) {
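The CloneEmptyDeviceAddress and SyncStream hunks show the second recurring migration in this commit: raw resource operations (stream sync, device-address creation, memory allocation, the collective-communication library) move off DeviceContext onto its device_res_manager_ member. A hedged sketch of the new call shape, using an invented helper SyncAndCloneAddress purely for illustration:

void SyncAndCloneAddress(const device::DeviceContext *device_context,
                         const device::DeviceAddressPtr &old_device_address) {
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(old_device_address);
  // Both operations now go through the resource manager instead of DeviceContext itself.
  (void)device_context->device_res_manager_->SyncStream();
  auto new_device_address = device_context->device_res_manager_->CreateDeviceAddress(
    nullptr, old_device_address->GetSize(), old_device_address->format(), old_device_address->type_id(),
    old_device_address->host_shape());
  MS_EXCEPTION_IF_NULL(new_device_address);
}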
@@ -475,7 +475,12 @@ uint32_t Debugger::GetRankID() {
uint32_t device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
const auto &device_context =
device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_target, device_id});
uint32_t rank_id = device_context->GetRankID();
uint32_t rank_id = 0;
auto deprecated_kernel_executor =
dynamic_cast<device::DeprecatedKernelExecutor *>(device_context->kernel_executor_.get());
if (deprecated_kernel_executor != nullptr) {
rank_id = deprecated_kernel_executor->GetRankID();
}
return rank_id;
}
@@ -90,8 +90,8 @@ void LoadInputs(const CNodePtr &cnode, const KernelLaunchInfo *launch_info, uint
auto host_format = kOpFormat_DEFAULT;
auto device_format =
E2eDump::IsDeviceTargetGPU() ? kOpFormat_DEFAULT : AnfAlgo::GetOutputFormat(input_kernel, PARAMETER_OUTPUT_INDEX);
auto device_addr =
device_context->CreateDeviceAddress(addr->addr, addr->size, device_format, device_type, ShapeVector());
auto device_addr = device_context->device_res_manager_->CreateDeviceAddress(addr->addr, addr->size, device_format,
device_type, ShapeVector());
string input_tensor_name = input_kernel_name + ':' + "0";
ShapeVector int_shapes;
GetDumpIntShape(input_kernel, PARAMETER_OUTPUT_INDEX, NOT_NULL(&int_shapes), trans_flag);
@@ -132,8 +132,8 @@ void LoadOutputs(const CNodePtr &cnode, const KernelLaunchInfo *launch_info, uin

auto host_format = kOpFormat_DEFAULT;
auto device_format = E2eDump::IsDeviceTargetGPU() ? kOpFormat_DEFAULT : AnfAlgo::GetOutputFormat(cnode, j);
auto device_addr =
device_context->CreateDeviceAddress(addr->addr, addr->size, device_format, device_type, ShapeVector());
auto device_addr = device_context->device_res_manager_->CreateDeviceAddress(addr->addr, addr->size, device_format,
device_type, ShapeVector());
string tensor_name = kernel_name + ':' + std::to_string(j);
ShapeVector int_shapes;
GetDumpIntShape(cnode, j, NOT_NULL(&int_shapes), trans_flag);
@@ -310,11 +310,11 @@ bool CollectiveManager::InitHostCommlib() {
device::DeviceContextKey host_key = {"CPU", 0};
host_ctx_ = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(host_key);
MS_EXCEPTION_IF_NULL(host_ctx_);
if (!host_ctx_->LoadCollectiveCommLib()) {
if (!host_ctx_->device_res_manager_->LoadCollectiveCommLib()) {
MS_LOG(ERROR) << "Failed to load communication library on the host side.";
return false;
}
host_comm_lib_instance_ = host_ctx_->collective_comm_lib();
host_comm_lib_instance_ = host_ctx_->device_res_manager_->collective_comm_lib();
MS_EXCEPTION_IF_NULL(host_comm_lib_instance_);

// For some communication libraries, 'global_rank_id_', 'global_rank_size_' should be set by caller, e.g., when using
@@ -360,11 +360,11 @@ bool CollectiveManager::InitDeviceCommLib() {
// We can initialize device context now because device id(local_rank_id_) is already assigned.
device_ctx_->Initialize();

if (!device_ctx_->LoadCollectiveCommLib()) {
if (!device_ctx_->device_res_manager_->LoadCollectiveCommLib()) {
MS_LOG(ERROR) << "Failed to load communication library on the device side.";
return false;
}
device_comm_lib_instance_ = device_ctx_->collective_comm_lib();
device_comm_lib_instance_ = device_ctx_->device_res_manager_->collective_comm_lib();
MS_EXCEPTION_IF_NULL(device_comm_lib_instance_);

MS_LOG(INFO) << "Start initializing communication library on device side...";
@@ -122,7 +122,7 @@ void EmbeddingCacheTableManager::AllocMemForEmbeddingCacheTable(const device::De
size_t embedding_size = item.second.embedding_size;
auto &device_address = item.second.device_address;
device_address.size = device_cache_size_ * embedding_size * sizeof(float);
auto addr = device_context->AllocateMemory(device_address.size);
auto addr = device_context->device_res_manager_->AllocateMemory(device_address.size);
MS_EXCEPTION_IF_NULL(addr);
device_address.addr = addr;
@@ -141,10 +141,10 @@ void EmbeddingCacheTableManager::AllocMemForEmbeddingCacheTable(const device::De
MS_EXCEPTION_IF_NULL(embedding_host_cache_);

embedding_device_cache_->hash_swap_index_addr_ =
reinterpret_cast<int *>(device_context->AllocateMemory(batch_ids_num_ * sizeof(int)));
reinterpret_cast<int *>(device_context->device_res_manager_->AllocateMemory(batch_ids_num_ * sizeof(int)));
MS_EXCEPTION_IF_NULL(embedding_device_cache_->hash_swap_index_addr_);
embedding_device_cache_->hash_swap_value_addr_ =
reinterpret_cast<float *>(device_context->AllocateMemory(max_embedding_size * batch_ids_num_ * sizeof(float)));
embedding_device_cache_->hash_swap_value_addr_ = reinterpret_cast<float *>(
device_context->device_res_manager_->AllocateMemory(max_embedding_size * batch_ids_num_ * sizeof(float)));
MS_EXCEPTION_IF_NULL(embedding_device_cache_->hash_swap_value_addr_);
}
@@ -133,7 +133,7 @@ void RecoveryContext::ObtainGlobalLatestCkptInfo() {
device::DeviceContextKey host_key = {"CPU", 0};
device::DeviceContext *host_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(host_key);
MS_EXCEPTION_IF_NULL(host_context);
device::CollectiveCommunicationLib *host_comm_lib_instance = host_context->collective_comm_lib();
device::CollectiveCommunicationLib *host_comm_lib_instance = host_context->device_res_manager_->collective_comm_lib();
MS_EXCEPTION_IF_NULL(host_comm_lib_instance);

if (global_rank_id_ >= global_rank_size_) {
@@ -71,7 +71,7 @@ class Environ {
const auto &device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
{value.second->device_name_, value.second->device_id_});
MS_EXCEPTION_IF_NULL(device_context);
device_context->FreeMemory(value.second->addr_);
device_context->device_res_manager_->FreeMemory(value.second->addr_);
}

values_.clear();
@@ -42,7 +42,7 @@ BlockQueueStatus_T AscendDataQueueDynamic::Push(std::vector<DataQueueItem> data)
MS_LOG(ERROR) << "Invalid Input: ptr: " << item.data_ptr_ << ", len: " << item.data_len_;
return ERROR_INPUT;
}
void *addr = device_context_->AllocateMemory(item.data_len_);
void *addr = device_context_->device_res_manager_->AllocateMemory(item.data_len_);
if (addr == nullptr) {
MS_LOG(ERROR) << "Allocate device memory of data queue failed";
}
@@ -184,7 +184,7 @@ void AscendDeviceAddress::BindDevice() const {
device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_name_, device_id_});
auto ascend_device_context = dynamic_cast<AscendDeviceContext *>(device_context);
MS_EXCEPTION_IF_NULL(ascend_device_context);
ascend_device_context->BindDeviceToCurrentThread();
ascend_device_context->device_res_manager_->BindDeviceToCurrentThread();
} else {
MS_LOG(WARNING) << "device name is null.";
}
@@ -15,244 +15,18 @@
|
|||
*/
|
||||
|
||||
#include "plugin/device/ascend/hal/hardware/ascend_device_context.h"
|
||||
#include <algorithm>
|
||||
#include <set>
|
||||
#include <unordered_map>
|
||||
#include "acl/acl_rt.h"
|
||||
#include "runtime/dev.h"
|
||||
#include "plugin/device/ascend/optimizer/ascend_backend_optimization.h"
|
||||
#include "common/graph_kernel/adapter/graph_kernel_optimization.h"
|
||||
#include "common/graph_kernel/graph_kernel_flags.h"
|
||||
#include "include/common/utils/utils.h"
|
||||
#include "include/common/utils/parallel_context.h"
|
||||
#include "plugin/device/ascend/hal/device/kernel_select_ascend.h"
|
||||
#include "runtime/device/kernel_adjust.h"
|
||||
#include "runtime/device/memory_manager.h"
|
||||
#include "plugin/device/ascend/hal/device/ascend_stream_assign.h"
|
||||
#include "plugin/device/ascend/hal/device/kernel_build_ascend.h"
|
||||
#include "plugin/device/ascend/hal/hardware/ascend_graph_optimization.h"
|
||||
#include "kernel/ascend_kernel_mod.h"
|
||||
#include "kernel/common_utils.h"
|
||||
#include "plugin/device/ascend/kernel/aicpu/aicpu_kernel_load.h"
|
||||
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_compile.h"
|
||||
#include "plugin/device/ascend/hal/device/ascend_bucket.h"
|
||||
#include "common/util/error_manager/error_manager.h"
|
||||
#include "plugin/device/ascend/hal/device/ascend_memory_adapter.h"
|
||||
#include "runtime/data_queue/data_queue_mgr.h"
|
||||
#include "backend/common/optimizer/common_backend_optimization.h"
|
||||
#ifndef ENABLE_SECURITY
|
||||
#include "debug/data_dump/dump_json_parser.h"
|
||||
#include "toolchain/adx_datadump_server.h"
|
||||
#include "toolchain/adx_datadump_callback.h"
|
||||
#include "include/common/debug/anf_ir_dump.h"
|
||||
#include "include/common/debug/dump_proto.h"
|
||||
#include "debug/data_dump/e2e_dump.h"
|
||||
#include "debug/debugger/debugger_utils.h"
|
||||
#endif
|
||||
#include "plugin/device/ascend/hal/hardware/ascend_utils.h"
|
||||
#ifdef ENABLE_DEBUGGER
|
||||
#include "debug/tensor_load.h"
|
||||
#include "debug/debugger/proto_exporter.h"
|
||||
#else
|
||||
#include "debug/debugger/proto_exporter_stub.h"
|
||||
#endif
|
||||
#ifdef ENABLE_DUMP_IR
|
||||
#include "include/common/debug/rdr/recorder_manager.h"
|
||||
#include "debug/rdr/graph_recorder.h"
|
||||
#endif
|
||||
|
||||
#ifndef ENABLE_SECURITY
|
||||
#include "profiler/device/ascend/memory_profiling.h"
|
||||
#include "plugin/device/ascend/hal/device/profiling/profiling_manager.h"
|
||||
#include "utils/anf_utils.h"
|
||||
#include "profiler/device/ascend/pynative_profiling.h"
|
||||
#include "profiler/device/ascend/ascend_profiling.h"
|
||||
|
||||
using Adx::AdxRegDumpProcessCallBack;
|
||||
using mindspore::device::ascend::ProfilingManager;
|
||||
using mindspore::profiler::ascend::MemoryProfiling;
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
namespace device {
|
||||
namespace ascend {
|
||||
using KernelGraph = mindspore::session::KernelGraph;
|
||||
const char kMsVm[] = "vm";
|
||||
constexpr size_t kAtomicCleanInputSize = 2;
|
||||
constexpr auto kUnknowErrorString = "Unknown error occurred";
|
||||
constexpr auto kAscend910 = "ascend910";
|
||||
namespace {
|
||||
CNodePtr GetNextLabelSet(const std::vector<CNodePtr> &kernel_nodes, uint32_t index) {
|
||||
size_t node_sizes = kernel_nodes.size();
|
||||
if (index >= node_sizes - 1) {
|
||||
MS_LOG(EXCEPTION) << "there is no node after this node:" << kernel_nodes[index]->DebugString();
|
||||
}
|
||||
auto kernel = kernel_nodes[index + 1];
|
||||
if (common::AnfAlgo::GetCNodeName(kernel) != kLabelSetOpName) {
|
||||
MS_LOG(EXCEPTION) << "the node is not labelset follow labelgoto/labelswitch, node: "
|
||||
<< kernel_nodes[index]->DebugString();
|
||||
}
|
||||
return kernel;
|
||||
}
|
||||
|
||||
std::vector<CNodePtr> HandleRecursiveCall(const std::vector<CNodePtr> &kernel_cnodes, const uint32_t &back_label,
|
||||
uint32_t *index, std::vector<CNodePtr> *back) {
|
||||
MS_EXCEPTION_IF_NULL(index);
|
||||
MS_EXCEPTION_IF_NULL(back);
|
||||
std::vector<CNodePtr> front;
|
||||
std::vector<CNodePtr> back_temp;
|
||||
bool back_flag = false;
|
||||
uint32_t i = *index;
|
||||
while (i < kernel_cnodes.size()) {
|
||||
if (!back_flag) {
|
||||
front.emplace_back(kernel_cnodes[i]);
|
||||
} else {
|
||||
back->emplace_back(kernel_cnodes[i]);
|
||||
}
|
||||
if (common::AnfAlgo::HasNodeAttr(kAttrRecursiveEnd, kernel_cnodes[i])) {
|
||||
*index = i;
|
||||
back->insert(back->end(), back_temp.begin(), back_temp.end());
|
||||
return front;
|
||||
}
|
||||
if (common::AnfAlgo::HasNodeAttr(kAttrRecursive, kernel_cnodes[i])) {
|
||||
back_flag = true;
|
||||
if (!common::AnfAlgo::IsLabelIndexInNode(kernel_cnodes[i], back_label)) {
|
||||
auto temp = HandleRecursiveCall(kernel_cnodes, back_label, &(++i), &back_temp);
|
||||
front.insert(front.end(), temp.begin(), temp.end());
|
||||
}
|
||||
}
|
||||
i++;
|
||||
}
|
||||
return front;
|
||||
}
|
||||
|
||||
void UnfoldRecursiveExecOrder(KernelGraph *kernel_graph) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
if (!kernel_graph->recursive_call()) {
|
||||
return;
|
||||
}
|
||||
auto kernel_cnodes = kernel_graph->mem_reuse_exec_order();
|
||||
std::vector<CNodePtr> mem_reuse_order;
|
||||
mem_reuse_order.reserve(kernel_cnodes.size());
|
||||
for (uint32_t i = 0; i < kernel_cnodes.size(); i++) {
|
||||
if (!common::AnfAlgo::HasNodeAttr(kAttrRecursiveStart, kernel_cnodes[i])) {
|
||||
mem_reuse_order.emplace_back(kernel_cnodes[i]);
|
||||
continue;
|
||||
}
|
||||
auto label_id = common::AnfAlgo::GetNodeAttr<uint32_t>(kernel_cnodes[i], kAttrLabelIndex);
|
||||
std::vector<CNodePtr> back;
|
||||
auto front = HandleRecursiveCall(kernel_cnodes, label_id, &i, &back);
|
||||
mem_reuse_order.insert(mem_reuse_order.end(), front.begin(), front.end());
|
||||
mem_reuse_order.insert(mem_reuse_order.end(), back.begin(), back.end());
|
||||
}
|
||||
kernel_graph->set_mem_reuse_exec_order(mem_reuse_order);
|
||||
}
|
||||
|
||||
void GetSubGraphExecOrder(const KernelGraph *kernel_graph, uint32_t index, const CNodePtr &back_node,
|
||||
std::vector<CNodePtr> *mem_reuse_order) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
MS_EXCEPTION_IF_NULL(mem_reuse_order);
|
||||
auto label_id = common::AnfAlgo::GetNodeAttr<uint32_t>(back_node, kAttrLabelIndex);
|
||||
auto kernel_cnodes = kernel_graph->execution_order();
|
||||
for (auto i = index; i < kernel_cnodes.size(); i++) {
|
||||
mem_reuse_order->emplace_back(kernel_cnodes[i]);
|
||||
if (common::AnfAlgo::IsLabelIndexInNode(kernel_cnodes[i], label_id)) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void InitMemReuseExecOrder(KernelGraph *kernel_graph) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
if (!kernel_graph->subgraph_multi_call()) {
|
||||
return;
|
||||
}
|
||||
std::unordered_map<uint32_t, uint32_t> label_id_index_map;
|
||||
auto kernel_cnodes = kernel_graph->execution_order();
|
||||
std::vector<CNodePtr> mem_reuse_order;
|
||||
for (uint32_t i = 0; i < kernel_cnodes.size(); i++) {
|
||||
mem_reuse_order.emplace_back(kernel_cnodes[i]);
|
||||
if (common::AnfAlgo::CheckPrimitiveType(kernel_cnodes[i], prim::kPrimLabelSwitch) &&
|
||||
!common::AnfAlgo::HasNodeAttr(kAttrRecursive, kernel_cnodes[i]) &&
|
||||
!common::AnfAlgo::HasNodeAttr(kAttrReturn, kernel_cnodes[i])) {
|
||||
auto label_list = common::AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(kernel_cnodes[i], kAttrLabelSwitchList);
|
||||
for (auto label_id : label_list) {
|
||||
if (label_id_index_map.find(label_id) == label_id_index_map.end()) {
|
||||
continue;
|
||||
}
|
||||
auto back_node = GetNextLabelSet(kernel_cnodes, i);
|
||||
GetSubGraphExecOrder(kernel_graph, label_id_index_map[label_id], back_node, &mem_reuse_order);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (common::AnfAlgo::CheckPrimitiveType(kernel_cnodes[i], prim::kPrimLabelGoto) &&
|
||||
!common::AnfAlgo::HasNodeAttr(kAttrRecursive, kernel_cnodes[i]) &&
|
||||
!common::AnfAlgo::HasNodeAttr(kAttrReturn, kernel_cnodes[i])) {
|
||||
auto label_id = common::AnfAlgo::GetNodeAttr<uint32_t>(kernel_cnodes[i], kAttrLabelIndex);
|
||||
if (label_id_index_map.find(label_id) == label_id_index_map.end()) {
|
||||
continue;
|
||||
}
|
||||
auto back_node = GetNextLabelSet(kernel_cnodes, i);
|
||||
GetSubGraphExecOrder(kernel_graph, label_id_index_map[label_id], back_node, &mem_reuse_order);
|
||||
continue;
|
||||
}
|
||||
if (common::AnfAlgo::CheckPrimitiveType(kernel_cnodes[i], prim::kPrimLabelSet) &&
|
||||
!common::AnfAlgo::HasNodeAttr(kAttrRecursive, kernel_cnodes[i])) {
|
||||
auto label_id = common::AnfAlgo::GetNodeAttr<uint32_t>(kernel_cnodes[i], kAttrLabelIndex);
|
||||
if (label_id_index_map.find(label_id) != label_id_index_map.end()) {
|
||||
MS_LOG(EXCEPTION) << "Two labelsets with same label id.";
|
||||
}
|
||||
label_id_index_map[label_id] = i;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
kernel_graph->set_mem_reuse_exec_order(mem_reuse_order);
|
||||
UnfoldRecursiveExecOrder(kernel_graph);
|
||||
}
|
||||
|
||||
// Before creating the kernel, check whether the node has completed the operator selection. If not, the operator
|
||||
// selection needs to be performed to set kernel info.
|
||||
void SetKernelInfoBeforeCreateKernel(const std::vector<CNodePtr> &nodes) {
|
||||
// Check whether the node has completed kernel selection.
|
||||
for (const auto &node : nodes) {
|
||||
if (AnfAlgo::GetSelectKernelBuildInfo(node) != nullptr) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Kernel selection process.
|
||||
auto [status, msg, etype] = SelectKernelInfoWithMsg(node);
|
||||
if (status == device::ascend::kNoMatched) {
|
||||
MS_EXCEPTION(etype) << msg;
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
/*
|
||||
* Feature group: Dump.
|
||||
* Target device group: Ascend.
|
||||
* Runtime category: MindRT.
|
||||
* Description: Parse config json file and register callback to adx.
|
||||
*/
|
||||
#ifndef ENABLE_SECURITY
|
||||
void DumpInit(uint32_t device_id) {
|
||||
auto &json_parser = DumpJsonParser::GetInstance();
|
||||
json_parser.Parse();
|
||||
json_parser.CopyDumpJsonToDir(device_id);
|
||||
json_parser.CopyHcclJsonToDir(device_id);
|
||||
json_parser.CopyMSCfgJsonToDir(device_id);
|
||||
if (json_parser.async_dump_enabled()) {
|
||||
#if !(defined(ENABLE_TEST) || defined(ENABLE_TESTCASES))
|
||||
// register callback to adx
|
||||
if (json_parser.FileFormatIsNpy()) {
|
||||
AdxRegDumpProcessCallBack(DumpDataCallBack);
|
||||
}
|
||||
#endif
|
||||
if (AdxDataDumpServerInit() != 0) {
|
||||
MS_LOG(EXCEPTION) << "Adx data dump server init failed";
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
void AscendDeviceContext::Initialize() {
|
||||
MS_LOG(INFO) << "Status record: Enter Initialize...";
|
||||
|
@@ -261,44 +35,22 @@ void AscendDeviceContext::Initialize() {
|
|||
runtime_instance_->SetContext();
|
||||
return;
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "Status record: Initialize start...";
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
auto device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
|
||||
runtime_instance_ = dynamic_cast<AscendKernelRuntime *>(
|
||||
device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id));
|
||||
MS_EXCEPTION_IF_NULL(device_res_manager_);
|
||||
device_res_manager_->Initialize();
|
||||
auto ascend_res_manager = dynamic_cast<AscendDeviceResManager *>(device_res_manager_.get());
|
||||
MS_EXCEPTION_IF_NULL(ascend_res_manager);
|
||||
runtime_instance_ = ascend_res_manager->runtime_instance_;
|
||||
MS_EXCEPTION_IF_NULL(runtime_instance_);
|
||||
if (!runtime_instance_->Init()) {
|
||||
MS_LOG(EXCEPTION) << "Kernel runtime init error.";
|
||||
}
|
||||
mem_manager_ = runtime_instance_->GetMemoryManager();
|
||||
MS_EXCEPTION_IF_NULL(mem_manager_);
|
||||
|
||||
auto env_rank_id = common::GetEnv("RANK_ID");
|
||||
if (ms_context->get_param<bool>(MS_CTX_ENABLE_HCCL) && !env_rank_id.empty()) {
|
||||
// get actual rank id if it's distribution training case.
|
||||
rank_id_ = GetRankId();
|
||||
}
|
||||
#ifndef ENABLE_SECURITY
|
||||
DumpInit(rank_id_);
|
||||
#endif
|
||||
compute_stream_ = runtime_instance_->compute_stream();
|
||||
communication_stream_ = runtime_instance_->communication_stream();
|
||||
|
||||
// Initialize tbe using HCCL rank_id
|
||||
kernel::ascend::TbeKernelCompileManager::GetInstance().TbeInitialize();
|
||||
|
||||
auto ascend_kernel_executor = dynamic_cast<AscendKernelExecutor *>(kernel_executor_.get());
|
||||
MS_EXCEPTION_IF_NULL(ascend_kernel_executor);
|
||||
ascend_kernel_executor->Initialize();
|
||||
auto ascend_graph_executor = dynamic_cast<AscendGraphExecutor *>(graph_executor_.get());
|
||||
MS_EXCEPTION_IF_NULL(ascend_graph_executor);
|
||||
ascend_graph_executor->Initialize();
|
||||
initialized_ = true;
|
||||
MS_LOG(INFO) << "Status record: Initialize success.";
|
||||
}
|
||||
|
||||
bool AscendDeviceContext::IsGraphMode() {
|
||||
auto context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
return context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode;
|
||||
}
|
||||
|
||||
void AscendDeviceContext::Destroy() {
|
||||
#ifdef ENABLE_DEBUGGER
|
||||
auto debugger = Debugger::GetInstance();
|
||||
|
@@ -314,17 +66,16 @@ void AscendDeviceContext::Destroy() {
|
|||
if (!initialized_) {
|
||||
return;
|
||||
}
|
||||
MS_LOG(INFO) << "Status record: Destroy start...";
|
||||
if (DataQueueMgr::GetInstance().IsInit()) {
|
||||
MS_EXCEPTION_IF_CHECK_FAIL(DataQueueMgr::GetInstance().Destroy(), "Could not destroy ascend data queue.");
|
||||
}
|
||||
|
||||
graph_event_.clear();
|
||||
rank_id_ = 0;
|
||||
MS_LOG(INFO) << "Status record: Destroy start...";
|
||||
auto ascend_graph_executor = dynamic_cast<AscendGraphExecutor *>(graph_executor_.get());
|
||||
ascend_graph_executor->Destroy();
|
||||
auto ascend_kernel_executor = dynamic_cast<AscendKernelExecutor *>(kernel_executor_.get());
|
||||
ascend_kernel_executor->Destroy();
|
||||
device_res_manager_->Destroy();
|
||||
if (runtime_instance_) {
|
||||
runtime_instance_ = nullptr;
|
||||
}
|
||||
AscendGraphOptimization::GetInstance().Reset();
|
||||
initialized_ = false;
|
||||
MS_LOG(INFO) << "Status record: Destroy success.";
|
||||
}
|
||||
|
@@ -336,13 +87,6 @@ bool AscendDeviceContext::PartitionGraph(const FuncGraphPtr &func_graph) const {
|
|||
return context_ptr->get_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK);
|
||||
}
|
||||
|
||||
bool IsDynamicShapeGraph(const FuncGraphPtr &func_graph) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return());
|
||||
return std::any_of(node_list.begin(), node_list.end(),
|
||||
[](const AnfNodePtr &node) { return common::AnfAlgo::IsDynamicShape(node); });
|
||||
}
|
||||
|
||||
RunMode AscendDeviceContext::GetRunMode(const FuncGraphPtr &func_graph) const {
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
|
@@ -353,663 +97,6 @@ RunMode AscendDeviceContext::GetRunMode(const FuncGraphPtr &func_graph) const {
|
|||
}
|
||||
}
|
||||
|
||||
void AscendDeviceContext::UnifyMindIR(const KernelGraphPtr &graph) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
AscendGraphOptimization::GetInstance().UnifyMindIR(graph);
|
||||
}
|
||||
|
||||
void AscendDeviceContext::OptimizeGraph(const FuncGraphPtr &graph) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
auto kernel_graph = graph->cast<KernelGraphPtr>();
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
if (kernel_graph->is_from_single_op()) {
|
||||
AscendGraphOptimization::GetInstance().OptimizeSingleOpGraph(kernel_graph);
|
||||
} else {
|
||||
AscendGraphOptimization::GetInstance().OptimizeGraph(kernel_graph);
|
||||
}
|
||||
}
|
||||
|
||||
void AscendDeviceContext::CreateKernel(const std::vector<CNodePtr> &nodes) const {
|
||||
SetKernelInfoBeforeCreateKernel(nodes);
|
||||
|
||||
MS_LOG(INFO) << "Status record: start create kernel.";
|
||||
PROF_START(create_kernel);
|
||||
auto ret = device::ascend::KernelBuild(nodes);
|
||||
if (!ret) {
|
||||
MS_LOG(EXCEPTION) << "Kernel build error.";
|
||||
}
|
||||
PROF_END(create_kernel);
|
||||
MS_LOG(INFO) << "Status record: end create kernel.";
|
||||
}
|
||||
|
||||
void AscendDeviceContext::LaunchDeviceLibrary() const {
|
||||
MS_LOG(INFO) << "Status record: start launch device library.";
|
||||
auto ret = mindspore::kernel::AicpuOpKernelLoad::GetInstance().LaunchAicpuKernelSo();
|
||||
if (!ret) {
|
||||
MS_LOG(EXCEPTION) << "Cust aicpu kernel so load failed.";
|
||||
}
|
||||
MS_LOG(INFO) << "Status record: end launch device library.";
|
||||
}
|
||||
|
||||
void AscendDeviceContext::UpdateExecOrder(const KernelGraphPtr &graph) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
std::vector<CNodePtr> new_orders;
|
||||
auto nodes = graph->execution_order();
|
||||
for (const auto &node : nodes) {
|
||||
if (node_atomics_.find(node) != node_atomics_.end()) {
|
||||
auto atomics = node_atomics_[node];
|
||||
(void)std::copy(atomics.begin(), atomics.end(), std::back_inserter(new_orders));
|
||||
}
|
||||
new_orders.push_back(node);
|
||||
}
|
||||
graph->set_execution_order(new_orders);
|
||||
node_atomics_.clear();
|
||||
}
|
||||
|
||||
void AscendDeviceContext::SetAtomicCleanToNodes(const KernelGraphPtr &graph,
|
||||
const std::map<CNodePtr, std::vector<CNodePtr>> &atomics_node) const {
|
||||
// don't clear node_atomics_ in the end, since atomic_clean_nodes_ in kernel.h is weakptr
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
auto nodes = graph->execution_order();
|
||||
for (const auto &node : nodes) {
|
||||
auto it = atomics_node.find(node);
|
||||
if (it != atomics_node.end()) {
|
||||
const auto &atomics = it->second;
|
||||
auto kernel_mod = AnfAlgo::GetKernelMod(node);
|
||||
auto ascend_kernel_mod = dynamic_cast<kernel::AscendKernelMod *>(kernel_mod);
|
||||
if (ascend_kernel_mod != nullptr) {
|
||||
ascend_kernel_mod->SetAtomicCleanNodes(atomics);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void AscendDeviceContext::PreprocessBeforeRun(const FuncGraphPtr &graph) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
auto kernel_graph = graph->cast<KernelGraphPtr>();
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
if (kernel_graph->is_from_single_op()) {
|
||||
PreprocessBeforeRunSingleOpGraph(kernel_graph);
|
||||
} else {
|
||||
PreprocessBeforeRunGraph(kernel_graph);
|
||||
}
|
||||
}
|
||||
|
||||
void AscendDeviceContext::PreprocessBeforeRunGraph(const KernelGraphPtr &graph) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_LOG(INFO) << "Status record: start preprocess before run graph. graph id: " << graph->graph_id();
|
||||
PROF_START(preprocess_before_run_graph);
|
||||
auto ascend_instance = profiler::ascend::AscendProfiler::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ascend_instance);
|
||||
if (graph->is_dynamic_shape()) {
|
||||
ascend_instance->SetNetDynamicShapeStatus();
|
||||
}
|
||||
SetErrorManagerContext();
|
||||
try {
|
||||
if (graph->is_graph_run_mode()) {
|
||||
device::ascend::InsertAtomicCleanOps(graph->execution_order(), &node_atomics_);
|
||||
UpdateExecOrder(graph);
|
||||
device::KernelAdjust::GetInstance().InsertDeviceLoopCtrl(graph);
|
||||
device::KernelAdjust::GetInstance().ProcessLoopSink(graph);
|
||||
AscendStreamAssign::GetInstance().AssignStream(NOT_NULL(graph));
|
||||
#ifndef ENABLE_SECURITY
|
||||
// Insert profiling point, this function must be executed after assign stream.
|
||||
device::KernelAdjust::GetInstance().Profiling(NOT_NULL(graph.get()));
|
||||
#endif
|
||||
CreateKernel(graph->execution_order());
|
||||
AllocateGraphMemory(NOT_NULL(graph));
|
||||
LoadModel(NOT_NULL(graph));
|
||||
AssignOutputNopNodeDeviceAddress(graph);
|
||||
} else if (graph->is_dynamic_shape() && (IsGraphMode() || graph->has_flag(kFlagPyNativeRunInGraph))) {
|
||||
device::ascend::InsertAtomicCleanOps(graph->execution_order(), &node_atomics_);
|
||||
SetAtomicCleanToNodes(graph, node_atomics_); // graph mode may can do it too, instead of update execorder
|
||||
opt::DynamicShapeConvertPass(graph);
|
||||
AscendStreamAssign::GetInstance().AssignStream(NOT_NULL(graph));
|
||||
AssignOutputNopNodeDeviceAddress(graph);
|
||||
LaunchDeviceLibrary();
|
||||
} else {
|
||||
PreprocessBeforeRunSingleOpGraph(graph);
|
||||
AscendStreamAssign::GetInstance().AssignStream(NOT_NULL(graph));
|
||||
CreateKernel(graph->execution_order());
|
||||
MS_EXCEPTION_IF_NULL(runtime_instance_);
|
||||
runtime_instance_->SetKernelModRtStream(NOT_NULL(graph));
|
||||
}
|
||||
} catch (const std::exception &e) {
|
||||
ReportErrorMessage();
|
||||
MS_LOG(EXCEPTION) << "Preprocess failed before run graph " << graph->graph_id() << ", \nerror msg: " << e.what();
|
||||
}
|
||||
|
||||
const std::vector<CNodePtr> &kernels = graph->execution_order();
|
||||
for (const auto &kernel : kernels) {
|
||||
common::AnfAlgo::SetNodeAttr(kAttrMSFunction, MakeValue(true), kernel);
|
||||
}
|
||||
|
||||
PROF_END(preprocess_before_run_graph);
|
||||
MS_LOG(INFO) << "Status record: end preprocess before run graph. graph id: " << graph->graph_id();
|
||||
}
|
||||
|
||||
void AscendDeviceContext::AssignOutputNopNodeDeviceAddress(const KernelGraphPtr &graph) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
auto outputs = common::AnfAlgo::GetAllOutput(graph->output(), {prim::kPrimTupleGetItem});
|
||||
for (auto output : outputs) {
|
||||
if (!output->isa<CNode>() || !AnfUtils::IsRealKernel(output)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!common::AnfAlgo::IsNopNode(output)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!common::AnfAlgo::IsNeedSkipNopOpAddr(output)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
size_t input_num = common::AnfAlgo::GetInputTensorNum(output);
|
||||
if (input_num != 1) {
|
||||
MS_LOG(WARNING) << "The input number of nop node :" << output->fullname_with_scope() << " is " << input_num
|
||||
<< ", not equal 1";
|
||||
continue;
|
||||
}
|
||||
|
||||
auto real_input_index = AnfAlgo::GetRealInputIndex(output, 0);
|
||||
auto pre_node_out_device_address = AnfAlgo::GetPrevNodeOutputAddr(output, real_input_index);
|
||||
MS_EXCEPTION_IF_NULL(pre_node_out_device_address);
|
||||
auto ptr = pre_node_out_device_address->GetPtr();
|
||||
auto size = pre_node_out_device_address->GetSize();
|
||||
std::string output_format = AnfAlgo::GetOutputFormat(output, 0);
|
||||
auto output_type = AnfAlgo::GetOutputDeviceDataType(output, 0);
|
||||
auto device_address = CreateDeviceAddress(const_cast<void *>(ptr), size, output_format, output_type);
|
||||
device_address->set_is_ptr_persisted(true);
|
||||
device_address->set_host_shape(trans::GetRuntimePaddingShape(output, 0));
|
||||
AnfAlgo::SetOutputAddr(device_address, 0, output.get());
|
||||
common::AnfAlgo::SetNodeAttr(kAttrSkipNopOpAddr, MakeValue(false), output);
|
||||
MS_LOG(INFO) << "Assign device address to output nop node " << output->fullname_with_scope();
|
||||
}
|
||||
}
|
||||
|
||||
void AscendDeviceContext::AllocateGraphMemory(const NotNull<KernelGraphPtr> &root_graph) const {
|
||||
MS_LOG(INFO) << "Status record: start memory alloc. graph id: " << root_graph->graph_id();
|
||||
PROF_START(graph_memory_alloc);
|
||||
MS_EXCEPTION_IF_NULL(runtime_instance_);
|
||||
runtime_instance_->ClearGlobalIdleMem();
|
||||
std::set<KernelGraphPtr> memo;
|
||||
memo.clear();
|
||||
mem_manager_->ResetDynamicMemory();
|
||||
AssignInputMemory(root_graph, NOT_NULL(&memo));
|
||||
device::KernelAdjust::GetInstance().AssignLoopCtrlMemory(*root_graph.get());
|
||||
InitMemReuseExecOrder(root_graph.get().get());
|
||||
runtime_instance_->SetReuseCommunicationAddress(*root_graph.get());
|
||||
runtime_instance_->AssignStaticMemoryOutput(*root_graph.get());
|
||||
runtime_instance_->AssignDynamicMemory(*root_graph.get());
|
||||
runtime_instance_->UpdateRefNodeOutputMem(*root_graph.get());
|
||||
|
||||
PROF_END(graph_memory_alloc);
|
||||
MS_LOG(INFO) << "Status record: end memory alloc. graph id: " << root_graph->graph_id()
|
||||
<< ", Memory Statistics: " << device::ascend::AscendMemAdapter::GetInstance().DevMemStatistics();
|
||||
MS_LOG(INFO) << "The dynamic memory pool total size is: "
|
||||
<< device::ascend::AscendMemoryPool::GetInstance().TotalMemStatistics() / kMBToByte
|
||||
<< "M, total used size is "
|
||||
<< device::ascend::AscendMemoryPool::GetInstance().TotalUsedMemStatistics() / kMBToByte
|
||||
<< "M, used peak size is "
|
||||
<< device::ascend::AscendMemoryPool::GetInstance().UsedMemPeakStatistics() / kMBToByte << "M.";
|
||||
|
||||
#ifndef ENABLE_SECURITY
|
||||
if (MemoryProfiling::GetInstance().IsMemoryProfilingInitialized()) {
|
||||
uint64_t mem_size = runtime_instance_->GetMsUsedHbmSize();
|
||||
MemoryProfiling::GetInstance().SetDeviceMemSize(mem_size);
|
||||
if (MemoryProfiling::GetInstance().NeedSaveMemoryProfiling()) {
|
||||
MemoryProfiling::GetInstance().SaveMemoryProfiling();
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
void AscendDeviceContext::AssignInputMemory(const NotNull<KernelGraphPtr> &graph,
|
||||
NotNull<std::set<KernelGraphPtr> *> const memo) const {
|
||||
if (memo->find(graph) != memo->end()) {
|
||||
return;
|
||||
}
|
||||
memo->insert(graph.get());
|
||||
|
||||
MS_LOG(INFO) << "Start to assign static memory for Parameter and Value node in graph: " << graph->graph_id();
|
||||
runtime_instance_->AssignStaticMemoryInput(*graph.get());
|
||||
runtime_instance_->AssignStaticMemoryValueNode(*graph.get());
|
||||
for (auto &child_graph : graph->child_graph_order()) {
|
||||
AssignInputMemory(NOT_NULL(child_graph.lock()), memo);
|
||||
}
|
||||
MS_LOG(INFO) << "Finish assigning static memory for Parameter and Value node in graph: " << graph->graph_id();
|
||||
}
|
||||
|
||||
void AscendDeviceContext::LoadModel(const NotNull<KernelGraphPtr> &root_graph) const {
|
||||
MS_LOG(INFO) << "Status record: start load model. graph id: " << root_graph->graph_id();
|
||||
PROF_START(load_model);
|
||||
MS_EXCEPTION_IF_NULL(runtime_instance_);
|
||||
bool ret_ok = runtime_instance_->Load(*root_graph.get(), true);
|
||||
if (!ret_ok) {
|
||||
MS_LOG(EXCEPTION) << "Load task error!";
|
||||
}
|
||||
PROF_END(load_model);
|
||||
MS_LOG(INFO) << "Status record: end load model. graph id: " << root_graph->graph_id();
|
||||
}
|
||||
|
||||
void *AscendDeviceContext::AllocateMemory(size_t size) const {
|
||||
MS_EXCEPTION_IF_NULL(runtime_instance_);
|
||||
MS_EXCEPTION_IF_NULL(mem_manager_);
|
||||
runtime_instance_->SetContext();
|
||||
return mem_manager_->MallocMemFromMemPool(size, false);
|
||||
}
|
||||
|
||||
void AscendDeviceContext::FreeMemory(void *ptr) const {
|
||||
MS_EXCEPTION_IF_NULL(ptr);
|
||||
MS_EXCEPTION_IF_NULL(mem_manager_);
|
||||
mem_manager_->FreeMemFromMemPool(ptr);
|
||||
}
|
||||
|
||||
std::vector<void *> AscendDeviceContext::AllocateContinuousMemory(const std::vector<size_t> &size_list) const {
|
||||
MS_EXCEPTION_IF_NULL(runtime_instance_);
|
||||
runtime_instance_->SetContext();
|
||||
std::vector<size_t> align_size_list;
|
||||
for (size_t i = 0; i < size_list.size(); i++) {
|
||||
auto align_size = device::MemoryManager::GetCommonAlignSize(size_list[i]);
|
||||
align_size_list.emplace_back(align_size);
|
||||
}
|
||||
return mem_manager_->MallocContinuousMemFromMemPool(align_size_list);
|
||||
}
|
||||
|
||||
bool AscendDeviceContext::ExecuteGraph(const KernelGraphPtr &graph) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
const uint64_t kUSecondInSecond = 1000000;
|
||||
bool ret = false;
|
||||
if (graph->is_graph_run_mode()) {
|
||||
InsertEventBeforeRunTask(graph);
|
||||
|
||||
#if defined(_WIN32) || defined(_WIN64)
|
||||
auto start_time = std::chrono::steady_clock::now();
|
||||
#else
|
||||
struct timeval start_time {};
|
||||
struct timeval end_time {};
|
||||
(void)gettimeofday(&start_time, nullptr);
|
||||
#endif
|
||||
MS_EXCEPTION_IF_NULL(runtime_instance_);
|
||||
{
|
||||
std::lock_guard<std::mutex> locker(launch_mutex_);
|
||||
ret = runtime_instance_->RunTask(*graph);
|
||||
}
|
||||
#if defined(_WIN32) || defined(_WIN64)
|
||||
auto end_time = std::chrono::steady_clock::now();
|
||||
std::chrono::duration<double, std::ratio<1, kUSecondInSecond>> cost = end_time - start_time;
|
||||
MS_LOG(INFO) << "Call MS Run Success in " << cost.count() << " us";
|
||||
#else
|
||||
(void)gettimeofday(&end_time, nullptr);
|
||||
uint64_t cost = kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec);
|
||||
cost += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec);
|
||||
MS_LOG(INFO) << "Call MS Run Success in " << cost << " us";
|
||||
#endif
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << graph->ToString() << " does not sink, should launch kernels";
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
bool AscendDeviceContext::LaunchGraph(const KernelGraphPtr &graph) const {
|
||||
MS_LOG(INFO) << "Status record: start launch graph. graph id: " << graph->graph_id();
|
||||
PROF_START(launch_graph);
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(runtime_instance_);
|
||||
runtime_instance_->SetContext();
|
||||
SetErrorManagerContext();
|
||||
device::KernelAdjust::GetInstance().LoadDeviceLoopCtrlParameters(graph);
|
||||
auto ret = ExecuteGraph(graph);
|
||||
if (!ret) {
|
||||
MS_LOG(ERROR) << "run task error!";
|
||||
ReportErrorMessage();
|
||||
return ret;
|
||||
}
|
||||
ReportWarningMessage();
|
||||
PROF_END(launch_graph);
|
||||
MS_LOG(INFO) << "Status record: end launch graph. graph id: " << graph->graph_id();
|
||||
return ret;
|
||||
}
|
||||
|
||||
void AscendDeviceContext::ReportErrorMessage() const {
|
||||
const string &error_message = ErrorManager::GetInstance().GetErrorMessage();
|
||||
if (!error_message.empty() && error_message.find(kUnknowErrorString) == string::npos) {
|
||||
MS_LOG(ERROR) << "Ascend error occurred, error message:\n" << error_message;
|
||||
}
|
||||
}
|
||||
|
||||
void AscendDeviceContext::SetErrorManagerContext() const { ErrorManager::GetInstance().GenWorkStreamIdDefault(); }
|
||||
|
||||
void AscendDeviceContext::ReportWarningMessage() const {
|
||||
const string &warning_message = ErrorManager::GetInstance().GetWarningMessage();
|
||||
if (!warning_message.empty()) {
|
||||
MS_LOG(WARNING) << "Ascend warning message:\n" << warning_message;
|
||||
}
|
||||
}
|
||||
|
||||
bool AscendDeviceContext::SyncStream(size_t stream_id) const {
|
||||
auto iter = stream_ids_.find(stream_id);
|
||||
if (iter != stream_ids_.end()) {
|
||||
MS_EXCEPTION_IF_NULL(iter->second);
|
||||
auto ret = rtStreamSynchronize(iter->second);
|
||||
if (ret != RT_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "Failed to synchronize ascend stream, ret[" << ret << "]";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
MS_EXCEPTION_IF_NULL(runtime_instance_);
|
||||
return runtime_instance_->SyncStream();
|
||||
}
|
||||
|
||||
bool AscendDeviceContext::CreateStream(void **stream) const {
|
||||
MS_EXCEPTION_IF_NULL(stream);
|
||||
auto ret = rtStreamCreate(stream, 0);
|
||||
if (ret != RT_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "Failed to create ascend stream, ret[" << ret << "]";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool AscendDeviceContext::DestroyStream(void *stream) const {
|
||||
MS_EXCEPTION_IF_NULL(stream);
|
||||
auto ret = rtStreamDestroy(stream);
|
||||
if (ret != RT_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "Failed to destroy ascend stream, ret[" << ret << "]";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void AscendDeviceContext::PreprocessBeforeRunSingleOpGraph(const KernelGraphPtr &graph) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
const auto &nodes = graph->execution_order();
|
||||
|
||||
for (const auto &node : nodes) {
|
||||
// Remove placeholder
|
||||
auto op_name = common::AnfAlgo::GetCNodeName(node);
|
||||
static const std::set<std::string> place_holder_nodes = {kDynamicRNNOpName, kDynamicGRUV2OpName};
|
||||
auto iter = place_holder_nodes.find(op_name);
|
||||
if (iter != place_holder_nodes.end()) {
|
||||
auto none_index = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(node, kAttrPlaceHolderIndex);
|
||||
// Remove seq_length
|
||||
auto input_num = common::AnfAlgo::GetInputTensorNum(node);
|
||||
std::vector<AnfNodePtr> new_inputs = {common::AnfAlgo::GetCNodePrimitiveNode(node)};
|
||||
for (size_t i = 0; i < input_num; ++i) {
|
||||
auto item = std::find(none_index.begin(), none_index.end(), i);
|
||||
if (item == none_index.end()) {
|
||||
auto input_node = common::AnfAlgo::GetInputNode(node, i);
|
||||
new_inputs.emplace_back(input_node);
|
||||
}
|
||||
}
|
||||
node->set_inputs(new_inputs);
|
||||
}
|
||||
|
||||
// Save the nop_op that needs to be memcpy
|
||||
static mindspore::HashSet<std::string> nop_nodes = {prim::kPrimReshape->name(), prim::kPrimExpandDims->name(),
|
||||
prim::kPrimSqueeze->name(), prim::kPrimFlatten->name(),
|
||||
prim::kPrimFlattenGrad->name()};
|
||||
// If the 2nd input of reshape is not a value node, then there are two inputs to select the host reshape operator
|
||||
bool is_host_reshape_op = false;
|
||||
if (op_name == prim::kPrimReshape->name()) {
|
||||
auto kernel_mod = AnfAlgo::GetKernelMod(node);
|
||||
MS_EXCEPTION_IF_NULL(kernel_mod);
|
||||
is_host_reshape_op = kernel_mod->GetKernelModType() == kernel::KernelModType::HostKernelMod;
|
||||
}
|
||||
bool nop_op_is_not_dynamic_shape = !graph->is_dynamic_shape() && nop_nodes.find(op_name) != nop_nodes.end();
|
||||
bool is_transpose_nop = op_name == prim::kPrimTranspose->name() && common::AnfAlgo::HasNodeAttr(kAttrNopOp, node);
|
||||
if (is_transpose_nop || (nop_op_is_not_dynamic_shape && !is_host_reshape_op)) {
|
||||
nop_op_to_memcpy_.insert(node);
|
||||
}
|
||||
}
|
||||
|
||||
device::ascend::InsertAtomicCleanOps(nodes, &node_atomics_persistent_cache_);
|
||||
std::vector<CNodePtr> atomic_nodes;
|
||||
for (const auto &node : nodes) {
|
||||
auto iter = node_atomics_persistent_cache_.find(node);
|
||||
if (iter != node_atomics_persistent_cache_.end()) {
|
||||
const auto &atomics = iter->second;
|
||||
std::copy(atomics.begin(), atomics.end(), std::back_inserter(atomic_nodes));
|
||||
}
|
||||
}
|
||||
|
||||
SetAtomicCleanToNodes(graph, node_atomics_persistent_cache_);
|
||||
CreateKernel(atomic_nodes);
|
||||
LaunchDeviceLibrary();
|
||||
}
|
||||
|
||||
std::shared_ptr<Bucket> AscendDeviceContext::CreateBucket(uint32_t bucket_id, uint32_t bucket_size) const {
|
||||
auto bucket = std::make_shared<AscendBucket>(bucket_id, bucket_size);
|
||||
MS_EXCEPTION_IF_NULL(bucket);
|
||||
|
||||
// For data-parallel, there is no communication in forward and backward process, the only communication ops arise
|
||||
// from this allreduce bucket. All the ops in forward and backward process are assigned on the compute stream and
|
||||
// allreduce for gradients is assigned on communication stream.
|
||||
// But for semi/auto_parallel mode, there will be communication ops in forward and backward process. To avoid stream
|
||||
// sync error, for semi/auto_parallel mode, the allreduce for gradients is assigned on compute stream as well.
|
||||
auto parallel_context = parallel::ParallelContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(parallel_context);
|
||||
auto parallel_mode = parallel_context->parallel_mode();
|
||||
if (parallel_mode == parallel::kAutoParallel || parallel_mode == parallel::kSemiAutoParallel) {
|
||||
bucket->Init({compute_stream_}, {compute_stream_});
|
||||
} else {
|
||||
bucket->Init({compute_stream_}, {communication_stream_});
|
||||
}
|
||||
return bucket;
|
||||
}
|
||||
|
||||
bool AscendDeviceContext::PySyncRuning() const {
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
if ((ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) &&
|
||||
ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE) && !SyncStream()) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool AscendDeviceContext::MemoryCopyAsync(const CNodePtr &node, const vector<AddressPtr> &inputs,
|
||||
const vector<AddressPtr> &outputs) const {
|
||||
MS_LOG(DEBUG) << "Launch MemoryCopyAsync instead for kernel " << node->fullname_with_scope();
|
||||
if (inputs.size() != 1 || outputs.size() != 1) {
|
||||
MS_LOG(ERROR) << "Kernel " << node->fullname_with_scope() << " input output size should be 1 but"
|
||||
<< " input size is:" << inputs.size() << " output size is:" << outputs.size();
|
||||
return false;
|
||||
}
|
||||
|
||||
aclError status = aclrtMemcpyAsync(outputs[0]->addr, outputs[0]->size, inputs[0]->addr, inputs[0]->size,
|
||||
ACL_MEMCPY_DEVICE_TO_DEVICE, compute_stream_);
|
||||
if (status != ACL_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "MemCpyAsync op aclrtMemcpyAsync failed, ret:" << status;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void *AscendDeviceContext::GetKernelStream(const CNodePtr &node) const {
|
||||
auto kernel_mod = AnfAlgo::GetKernelMod(node);
|
||||
MS_EXCEPTION_IF_NULL(kernel_mod);
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
|
||||
return compute_stream_;
|
||||
} else if (common::AnfAlgo::HasNodeAttr(kAttrStream, node)) {
|
||||
auto stream_id = common::AnfAlgo::GetNodeAttr<size_t>(node, kAttrStream);
|
||||
auto iter = stream_ids_.find(stream_id);
|
||||
if (iter == stream_ids_.end()) {
|
||||
MS_LOG(EXCEPTION) << "Can not find stream for stream id: " << stream_id;
|
||||
}
|
||||
void *stream = iter->second;
|
||||
MS_EXCEPTION_IF_NULL(stream);
|
||||
return stream;
|
||||
} else {
|
||||
auto stream = kernel_mod->stream();
|
||||
if (stream == nullptr) {
|
||||
stream = compute_stream_;
|
||||
MS_LOG(INFO) << "Assign default compute stream for node " << node->fullname_with_scope();
|
||||
}
|
||||
return stream;
|
||||
}
|
||||
}
|
||||
|
||||
bool AscendDeviceContext::GetKernelRealInputs(const CNodePtr &kernel, const vector<AddressPtr> &inputs,
|
||||
std::vector<AddressPtr> *real_inputs) const {
|
||||
auto input_num = common::AnfAlgo::GetInputTensorNum(kernel);
|
||||
if (input_num != inputs.size()) {
|
||||
MS_LOG(ERROR) << "Input num is " << input_num << " but input address num is " << inputs.size();
|
||||
return false;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < input_num; ++i) {
|
||||
auto real_index = AnfAlgo::GetRealInputIndex(kernel, i);
|
||||
if (real_index >= input_num) {
|
||||
MS_LOG(ERROR) << "Total input num is " << input_num << " but get real_index " << real_index;
|
||||
return false;
|
||||
}
|
||||
real_inputs->push_back(inputs[real_index]);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool AscendDeviceContext::LaunchKernel(const CNodePtr &kernel, const vector<AddressPtr> &inputs,
|
||||
const vector<AddressPtr> &workspace, const vector<AddressPtr> &outputs) const {
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
auto graph_id = AnfAlgo::GetGraphId(kernel.get());
|
||||
auto device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
|
||||
KernelType kernel_type = AnfAlgo::GetKernelType(kernel);
|
||||
MS_EXCEPTION_IF_NULL(kernel);
|
||||
MS_LOG(DEBUG) << "Launch kernel: " << kernel->fullname_with_scope();
|
||||
BindDeviceToCurrentThread();
|
||||
|
||||
std::vector<AddressPtr> real_inputs;
|
||||
bool ret = GetKernelRealInputs(kernel, inputs, &real_inputs);
|
||||
if (!ret) {
|
||||
MS_LOG(ERROR) << "Get real input fail for kernel " << kernel->fullname_with_scope();
|
||||
return false;
|
||||
}
|
||||
auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
|
||||
MS_EXCEPTION_IF_NULL(kernel_mod);
|
||||
|
||||
bool is_dynamic_shape = common::AnfAlgo::IsDynamicShape(kernel);
|
||||
if (!is_dynamic_shape || !(common::AnfAlgo::GetBooleanAttr(kernel, kAttrMSFunction))) {
|
||||
std::lock_guard<std::mutex> locker(launch_mutex_);
|
||||
// launch atomic clean
|
||||
if (!LaunchAtomicClean(kernel, workspace, outputs)) {
|
||||
MS_LOG(ERROR) << "Launch AtomicClean failed, pre kernel full name: " << kernel->fullname_with_scope();
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// launch kernel
|
||||
if (nop_op_to_memcpy_.find(kernel) != nop_op_to_memcpy_.end()) {
|
||||
MemoryCopyAsync(kernel, real_inputs, outputs);
|
||||
} else {
|
||||
MS_LOG(DEBUG) << "Launch kernel " << kernel->fullname_with_scope();
|
||||
auto stream = GetKernelStream(kernel);
|
||||
#ifndef ENABLE_SECURITY
|
||||
auto profiler_inst = profiler::ascend::PynativeProfiler::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(profiler_inst);
|
||||
std::thread::id t_id = std::this_thread::get_id();
|
||||
(void)profiler_inst->OpDataProducerBegin(runtime_instance_, stream, t_id, kernel->fullname_with_scope(),
|
||||
is_dynamic_shape);
|
||||
#endif
|
||||
ret = kernel_mod->Launch(real_inputs, workspace, outputs, stream);
|
||||
#ifndef ENABLE_SECURITY
|
||||
(void)profiler_inst->OpDataProducerEnd(t_id, is_dynamic_shape);
|
||||
#endif
|
||||
if (!ret) {
|
||||
MS_LOG(ERROR) << "Launch kernel failed, kernel full name: " << kernel->fullname_with_scope();
|
||||
return false;
|
||||
}
|
||||
}
|
||||
auto ascend_instance = profiler::ascend::AscendProfiler::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ascend_instance);
|
||||
if (ascend_instance->GetNetDynamicShapeStatus() && ascend_instance->GetProfilingEnableFlag()) {
|
||||
ascend_instance->GetNodeTaskIdStreamId(kernel, graph_id, device_id, kernel_type);
|
||||
}
|
||||
|
||||
return PySyncRuning();
|
||||
}
|
||||
|
||||
bool AscendDeviceContext::BindDeviceToCurrentThread() const {
|
||||
if (initialized_) {
|
||||
runtime_instance_->SetContext();
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool AscendDeviceContext::LaunchAtomicClean(const CNodePtr &node, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) const {
|
||||
auto iter = node_atomics_persistent_cache_.find(node);
|
||||
if (iter == node_atomics_persistent_cache_.end()) {
|
||||
return true;
|
||||
}
|
||||
MS_LOG(DEBUG) << "Launch atomic clean for kernel " << node->fullname_with_scope();
|
||||
auto atomic_node = iter->second.at(0);
|
||||
vector<AddressPtr> atomic_inputs;
|
||||
// The output addr need to clean
|
||||
MS_EXCEPTION_IF_NULL(atomic_node);
|
||||
if (atomic_node->inputs().size() != kAtomicCleanInputSize) {
|
||||
MS_LOG(EXCEPTION) << "Atomic Addr clean Node Input nodes not equal 2.";
|
||||
}
|
||||
if (common::AnfAlgo::HasNodeAttr(kAttrAtomicOutputIndexs, node)) {
|
||||
auto clean_output_indexes = common::AnfAlgo::GetNodeAttr<std::vector<size_t>>(node, kAttrAtomicOutputIndexs);
|
||||
for (auto output_index : clean_output_indexes) {
|
||||
if (output_index >= outputs.size()) {
|
||||
MS_LOG(EXCEPTION) << "Invalid output_index:" << output_index << " except less than " << outputs.size();
|
||||
}
|
||||
atomic_inputs.push_back(outputs[output_index]);
|
||||
}
|
||||
}
|
||||
|
||||
// The workspace addr need to clean
|
||||
if (common::AnfAlgo::HasNodeAttr(kAttrAtomicWorkspaceIndexs, node)) {
|
||||
auto clean_workspace_indexes = common::AnfAlgo::GetNodeAttr<std::vector<size_t>>(node, kAttrAtomicWorkspaceIndexs);
|
||||
for (auto workspace_index : clean_workspace_indexes) {
|
||||
if (workspace_index >= workspace.size()) {
|
||||
MS_LOG(EXCEPTION) << "Invalid workspace_index:" << workspace_index << " except less than " << workspace.size();
|
||||
}
|
||||
atomic_inputs.push_back(workspace[workspace_index]);
|
||||
}
|
||||
}
|
||||
// Launch Atomic Node
|
||||
auto kernel_mod = AnfAlgo::GetKernelMod(atomic_node);
|
||||
MS_EXCEPTION_IF_NULL(kernel_mod);
|
||||
return kernel_mod->Launch(atomic_inputs, {}, {}, GetKernelStream(node));
|
||||
}
|
||||
|
||||
void AscendDeviceContext::InsertEventBeforeRunTask(const KernelGraphPtr &graph) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
if (!graph->is_graph_run_mode() || graph->is_dynamic_shape()) {
|
||||
return;
|
||||
}
|
||||
MS_LOG(DEBUG) << "Insert event between PyNative and Graph";
|
||||
MS_EXCEPTION_IF_NULL(runtime_instance_);
|
||||
auto model_stream = runtime_instance_->GetModelStream(graph->graph_id());
|
||||
auto compute_event = runtime_instance_->CreateDeviceEvent();
|
||||
MS_EXCEPTION_IF_NULL(compute_event);
|
||||
compute_event->set_wait_stream(model_stream);
|
||||
compute_event->set_record_stream(compute_stream_);
|
||||
compute_event->RecordEvent();
|
||||
compute_event->WaitEvent();
|
||||
graph_event_[graph->graph_id()] = compute_event;
|
||||
}
|
||||
|
||||
DeviceAddressPtr AscendDeviceContext::CreateDeviceAddress(void *const device_ptr, size_t device_size,
|
||||
const string &format, TypeId type_id,
|
||||
const ShapeVector &shape) const {
|
||||
auto device_address = std::make_shared<AscendDeviceAddress>(
|
||||
device_ptr, device_size, format, type_id, device_context_key_.device_name_, device_context_key_.device_id_);
|
||||
if (shape.empty()) {
|
||||
MS_LOG(DEBUG) << "shape size is empty.";
|
||||
}
|
||||
device_address->set_host_shape(shape);
|
||||
return device_address;
|
||||
}
|
||||
|
||||
MS_REGISTER_DEVICE(kAscendDevice, AscendDeviceContext);
|
||||
} // namespace ascend
|
||||
} // namespace device
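MS_REGISTER_DEVICE(kAscendDevice, AscendDeviceContext) registers this context under the Ascend device name. As a rough, hedged illustration of the kind of self-registering factory such a macro usually expands to (the registry, macro-free registrar and all names below are assumptions made for illustration, not MindSpore's actual implementation):

// Illustrative sketch only: a generic self-registering factory of the kind a macro like
// MS_REGISTER_DEVICE typically expands to. Every name here is a stand-in.
#include <functional>
#include <map>
#include <memory>
#include <string>

struct DemoDeviceContext {
  virtual ~DemoDeviceContext() = default;
  virtual std::string Name() const = 0;
};

class DemoDeviceRegistry {
 public:
  static DemoDeviceRegistry &Instance() {
    static DemoDeviceRegistry registry;
    return registry;
  }
  void Register(const std::string &name, std::function<std::shared_ptr<DemoDeviceContext>()> creator) {
    creators_[name] = std::move(creator);
  }
  std::shared_ptr<DemoDeviceContext> Create(const std::string &name) const {
    auto iter = creators_.find(name);
    return iter == creators_.end() ? nullptr : iter->second();
  }

 private:
  std::map<std::string, std::function<std::shared_ptr<DemoDeviceContext>()>> creators_;
};

struct DemoAscendContext : DemoDeviceContext {
  std::string Name() const override { return "Ascend"; }
};

// A static registrar runs during start-up, so simply linking the plugin makes the
// device context creatable by name.
static const bool g_demo_ascend_registered = [] {
  DemoDeviceRegistry::Instance().Register("Ascend", [] { return std::make_shared<DemoAscendContext>(); });
  return true;
}();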
|
||||
|
|
|
@ -24,17 +24,21 @@
|
|||
#include <map>
|
||||
#include "runtime/hardware/device_context.h"
|
||||
#include "runtime/hardware/device_context_manager.h"
|
||||
#include "runtime/device/memory_manager.h"
|
||||
#include "plugin/device/ascend/hal/device/ascend_kernel_runtime.h"
|
||||
#include "plugin/device/ascend/hal/device/ascend_device_address.h"
|
||||
#include "plugin/device/ascend/hal/hardware/ascend_device_res_manager.h"
|
||||
#include "plugin/device/ascend/hal/hardware/ascend_kernel_executor.h"
|
||||
#include "plugin/device/ascend/hal/hardware/ascend_graph_executor.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace device {
|
||||
namespace ascend {
|
||||
class AscendDeviceContext : public DeviceContext {
|
||||
class AscendGraphExecutor;
|
||||
class AscendKernelExecutor;
|
||||
class AscendDeviceResManager;
|
||||
|
||||
class AscendDeviceContext : public DeviceInterface<AscendGraphExecutor, AscendKernelExecutor, AscendDeviceResManager> {
|
||||
public:
|
||||
explicit AscendDeviceContext(const DeviceContextKey &device_context_key)
|
||||
: DeviceContext(device_context_key), mem_manager_(nullptr), initialized_(false) {}
|
||||
: DeviceInterface(device_context_key), initialized_(false) {}
|
||||
~AscendDeviceContext() override = default;
|
||||
|
||||
// Initialize the device context.
|
||||
|
@ -43,115 +47,13 @@ class AscendDeviceContext : public DeviceContext {
|
|||
// Destroy device context and release device resource.
|
||||
void Destroy() override;
|
||||
|
||||
// Get rank id for distributed training.
|
||||
uint32_t GetRankID() const override { return rank_id_; }
|
||||
|
||||
bool PartitionGraph(const FuncGraphPtr &func_graph) const override;
|
||||
|
||||
RunMode GetRunMode(const FuncGraphPtr &func_graph) const override;
|
||||
// Optimize the kernel graph for graph mode.
|
||||
void OptimizeGraph(const FuncGraphPtr &graph) const override;
|
||||
|
||||
// Generate 'KernelMod' for all kernels and set 'KernelMod' into kernel,
|
||||
// 'KernelMod' is real executive object of kernel.
|
||||
void CreateKernel(const std::vector<CNodePtr> &nodes) const override;
|
||||
|
||||
// Adjust kernel graph before run graph, used in Graph Mode.
|
||||
void PreprocessBeforeRun(const FuncGraphPtr &graph) const override;
|
||||
|
||||
// Relevant function to allocate and free device memory of raw ptr.
|
||||
void *AllocateMemory(size_t size) const override;
|
||||
void FreeMemory(void *ptr) const override;
|
||||
|
||||
// Allocate continuous device memory according to size list.
|
||||
// Communication operators may need continuous memory for input and output
|
||||
// to optimize the communication performance.
|
||||
std::vector<void *> AllocateContinuousMemory(const std::vector<size_t> &size_list) const override;
|
||||
|
||||
// Create concrete device address according different device type.
|
||||
DeviceAddressPtr CreateDeviceAddress(void *const device_ptr, size_t device_size, const string &format, TypeId type_id,
|
||||
const ShapeVector &shape = ShapeVector()) const override;
|
||||
|
||||
// Launch graph, device such as Ascend support the whole graph sink to the device executing.
|
||||
bool LaunchGraph(const KernelGraphPtr &graph) const override;
|
||||
|
||||
// Launch a kernel via 'KernelMod' of the kernel.
|
||||
bool LaunchKernel(const CNodePtr &kernel, const std::vector<AddressPtr> &inputs,
|
||||
const std::vector<AddressPtr> &workspace, const std::vector<AddressPtr> &outputs) const override;
|
||||
|
||||
// Synchronize stream, device such as GPU and Ascend need stream to launch kernel asynchronously,
|
||||
// using 'SyncStream' to block thread and wait for completing all tasks in stream.
|
||||
// Devices that do not need stream could ignore the implementation of this function.
|
||||
bool SyncStream(size_t stream_id = 0) const override;
|
||||
|
||||
// Create and initialize bucket for every allreduce operator. Bucket is used in PyNative distributed training mode,
|
||||
// one bucket handles all resource to launch and sync allreduce operator.
|
||||
std::shared_ptr<Bucket> CreateBucket(uint32_t bucket_id, uint32_t bucket_size) const override;
|
||||
|
||||
// Unify the MindIR, the default behavior uses the common unified MindIR.
|
||||
void UnifyMindIR(const KernelGraphPtr &graph) const override;
|
||||
|
||||
// set rt_context_ to this thread to control device
|
||||
bool BindDeviceToCurrentThread() const override;
|
||||
|
||||
// Launch device aicpu library
|
||||
void LaunchDeviceLibrary() const;
|
||||
|
||||
private:
|
||||
// Graph loader interface
|
||||
void AllocateGraphMemory(const NotNull<KernelGraphPtr> &root_graph) const;
|
||||
void AssignInputMemory(const NotNull<KernelGraphPtr> &graph, NotNull<std::set<KernelGraphPtr> *> memo) const;
|
||||
void LoadModel(const NotNull<KernelGraphPtr> &root_graph) const;
|
||||
void UpdateExecOrder(const KernelGraphPtr &graph) const;
|
||||
static bool IsGraphMode();
|
||||
bool PySyncRuning() const;
|
||||
bool MemoryCopyAsync(const CNodePtr &node, const vector<AddressPtr> &inputs, const vector<AddressPtr> &outputs) const;
|
||||
void InsertEventBeforeRunTask(const KernelGraphPtr &graph) const;
|
||||
void SetAtomicCleanToNodes(const KernelGraphPtr &graph,
|
||||
const std::map<CNodePtr, std::vector<CNodePtr>> &atomics_node) const;
|
||||
|
||||
void ReportErrorMessage() const;
|
||||
void ReportWarningMessage() const;
|
||||
void SetErrorManagerContext() const;
|
||||
|
||||
// Really create an ascend stream.
|
||||
bool CreateStream(void **stream) const override;
|
||||
|
||||
// Really destroy an ascend stream.
|
||||
bool DestroyStream(void *stream) const override;
|
||||
|
||||
// Kernel Runtime --- only for task sink
|
||||
AscendKernelRuntime *runtime_instance_{nullptr};
|
||||
std::shared_ptr<MemoryManager> mem_manager_{nullptr};
|
||||
// rank id of physical device
|
||||
uint32_t rank_id_{0};
|
||||
bool initialized_{false};
|
||||
|
||||
// LaunchGraph interface
|
||||
bool ExecuteGraph(const KernelGraphPtr &graph) const;
|
||||
// ExecuteGraph is not thread-safe, so it is not recommended that multiple threads call it at the same
|
||||
// time; the launch mutex is needed when multiple threads launch the graph.
|
||||
mutable std::mutex launch_mutex_;
|
||||
// Using the node to get its atomics
|
||||
mutable std::map<CNodePtr, std::vector<CNodePtr>> node_atomics_;
|
||||
// Persistent cache for single op execution.
|
||||
// node_atomics_ will be cleaned up in CompileGraph.
|
||||
mutable std::map<CNodePtr, std::vector<CNodePtr>> node_atomics_persistent_cache_;
|
||||
mutable std::set<CNodePtr> nop_op_to_memcpy_;
|
||||
// Event for multi-stream
|
||||
mutable std::map<uint32_t, std::shared_ptr<DeviceEvent>> graph_event_;
|
||||
// Some NOP nodes have been hidden in the execution order and have no output device address. This function creates
|
||||
// output device addresses for these nodes; the output device address is the same as the input device address.
|
||||
void AssignOutputNopNodeDeviceAddress(const KernelGraphPtr &graph) const;
|
||||
bool LaunchAtomicClean(const CNodePtr &node, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) const;
|
||||
void *compute_stream_;
|
||||
void *communication_stream_;
|
||||
void *GetKernelStream(const CNodePtr &node) const;
|
||||
bool GetKernelRealInputs(const CNodePtr &kernel, const vector<AddressPtr> &inputs,
|
||||
std::vector<AddressPtr> *real_inputs) const;
|
||||
void PreprocessBeforeRunGraph(const KernelGraphPtr &graph) const;
|
||||
void PreprocessBeforeRunSingleOpGraph(const KernelGraphPtr &graph) const;
|
||||
AscendKernelRuntime *runtime_instance_{nullptr};
|
||||
};
|
||||
} // namespace ascend
|
||||
} // namespace device
|
||||
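The header hunk above strips the memory, stream, kernel and graph-sink responsibilities out of AscendDeviceContext and re-parents it onto DeviceInterface<AscendGraphExecutor, AscendKernelExecutor, AscendDeviceResManager>. Below is a self-contained sketch of that composition pattern using made-up demo types; the member names mirror the spellings visible in this diff, but nothing here is the real MindSpore code.

// Self-contained sketch of the component split: a DeviceInterface-style template owns
// the graph executor, kernel executor and resource manager, and callers go through
// those members instead of one monolithic context. All types are illustrative.
#include <cstddef>
#include <cstdlib>
#include <iostream>
#include <memory>

struct DemoResManager {
  void *AllocateMemory(std::size_t size) { return std::malloc(size); }
  void FreeMemory(void *ptr) { std::free(ptr); }
};
struct DemoKernelExecutor {
  void CreateKernel() { std::cout << "create kernels\n"; }
};
struct DemoGraphExecutor {
  void RunGraph() { std::cout << "run sunk graph\n"; }
};

template <typename GraphExecutorT, typename KernelExecutorT, typename ResManagerT>
struct DemoDeviceInterface {
  std::unique_ptr<GraphExecutorT> graph_executor_ = std::make_unique<GraphExecutorT>();
  std::unique_ptr<KernelExecutorT> kernel_executor_ = std::make_unique<KernelExecutorT>();
  std::unique_ptr<ResManagerT> device_res_manager_ = std::make_unique<ResManagerT>();
};

struct DemoAscendContext : DemoDeviceInterface<DemoGraphExecutor, DemoKernelExecutor, DemoResManager> {};

int main() {
  DemoAscendContext ctx;
  void *buf = ctx.device_res_manager_->AllocateMemory(1024);  // memory lives in the res manager
  ctx.kernel_executor_->CreateKernel();                       // kernel build lives in the kernel executor
  ctx.graph_executor_->RunGraph();                            // graph sink lives in the graph executor
  ctx.device_res_manager_->FreeMemory(buf);
  return 0;
}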
|
|
|
@ -0,0 +1,145 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "plugin/device/ascend/hal/hardware/ascend_device_res_manager.h"
|
||||
#include "runtime/data_queue/data_queue_mgr.h"
|
||||
#include "runtime/rt.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace device {
|
||||
namespace ascend {
|
||||
void AscendDeviceResManager::Initialize() {
|
||||
MS_LOG(INFO) << "Status record: Device resource manager Initialize start...";
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
auto device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
|
||||
runtime_instance_ = dynamic_cast<AscendKernelRuntime *>(
|
||||
device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id));
|
||||
MS_EXCEPTION_IF_NULL(runtime_instance_);
|
||||
if (!runtime_instance_->Init()) {
|
||||
MS_LOG(EXCEPTION) << "Kernel runtime init error.";
|
||||
}
|
||||
mem_manager_ = runtime_instance_->GetMemoryManager();
|
||||
MS_EXCEPTION_IF_NULL(mem_manager_);
|
||||
|
||||
auto env_rank_id = common::GetEnv("RANK_ID");
|
||||
if (ms_context->get_param<bool>(MS_CTX_ENABLE_HCCL) && !env_rank_id.empty()) {
|
||||
// Get the actual rank id in the distributed training case.
|
||||
rank_id_ = GetRankId();
|
||||
}
|
||||
|
||||
compute_stream_ = runtime_instance_->compute_stream();
|
||||
MS_EXCEPTION_IF_NULL(compute_stream_);
|
||||
communication_stream_ = runtime_instance_->communication_stream();
|
||||
MS_EXCEPTION_IF_NULL(communication_stream_);
|
||||
|
||||
MS_LOG(INFO) << "Status record: Device resource manager Initialize success.";
|
||||
}
|
||||
|
||||
void AscendDeviceResManager::Destroy() {
|
||||
MS_LOG(INFO) << "Status record: Device resource manager Destroy start...";
|
||||
if (DataQueueMgr::GetInstance().IsInit()) {
|
||||
MS_EXCEPTION_IF_CHECK_FAIL(DataQueueMgr::GetInstance().Destroy(), "Could not destroy ascend data queue.");
|
||||
}
|
||||
|
||||
rank_id_ = 0;
|
||||
if (runtime_instance_) {
|
||||
runtime_instance_ = nullptr;
|
||||
}
|
||||
MS_LOG(INFO) << "Status record: Device resource manager Destroy success.";
|
||||
}
|
||||
|
||||
bool AscendDeviceResManager::BindDeviceToCurrentThread() const {
|
||||
MS_EXCEPTION_IF_NULL(runtime_instance_);
|
||||
runtime_instance_->SetContext();
|
||||
return true;
|
||||
}
|
||||
|
||||
void *AscendDeviceResManager::AllocateMemory(size_t size) const {
|
||||
MS_EXCEPTION_IF_NULL(runtime_instance_);
|
||||
MS_EXCEPTION_IF_NULL(mem_manager_);
|
||||
runtime_instance_->SetContext();
|
||||
return mem_manager_->MallocMemFromMemPool(size, false);
|
||||
}
|
||||
|
||||
void AscendDeviceResManager::FreeMemory(void *ptr) const {
|
||||
MS_EXCEPTION_IF_NULL(ptr);
|
||||
MS_EXCEPTION_IF_NULL(mem_manager_);
|
||||
mem_manager_->FreeMemFromMemPool(ptr);
|
||||
}
|
||||
|
||||
std::vector<void *> AscendDeviceResManager::AllocateContinuousMemory(const std::vector<size_t> &size_list) const {
|
||||
MS_EXCEPTION_IF_NULL(runtime_instance_);
|
||||
runtime_instance_->SetContext();
|
||||
std::vector<size_t> align_size_list;
|
||||
for (size_t i = 0; i < size_list.size(); i++) {
|
||||
auto align_size = device::MemoryManager::GetCommonAlignSize(size_list[i]);
|
||||
align_size_list.emplace_back(align_size);
|
||||
}
|
||||
return mem_manager_->MallocContinuousMemFromMemPool(align_size_list);
|
||||
}
|
||||
|
||||
DeviceAddressPtr AscendDeviceResManager::CreateDeviceAddress(void *const device_ptr, size_t device_size,
|
||||
const string &format, TypeId type_id,
|
||||
const ShapeVector &shape) const {
|
||||
auto device_address = std::make_shared<AscendDeviceAddress>(device_ptr, device_size, format, type_id,
|
||||
device_context_->device_context_key().device_name_,
|
||||
device_context_->device_context_key().device_id_);
|
||||
if (shape.empty()) {
|
||||
MS_LOG(DEBUG) << "shape size is empty.";
|
||||
}
|
||||
device_address->set_host_shape(shape);
|
||||
return device_address;
|
||||
}
|
||||
|
||||
bool AscendDeviceResManager::SyncStream(size_t stream_id) const {
|
||||
auto iter = stream_ids_.find(stream_id);
|
||||
if (iter != stream_ids_.end()) {
|
||||
MS_EXCEPTION_IF_NULL(iter->second);
|
||||
auto ret = rtStreamSynchronize(iter->second);
|
||||
if (ret != RT_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "Failed to synchronize ascend stream, ret[" << ret << "]";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
MS_EXCEPTION_IF_NULL(runtime_instance_);
|
||||
return runtime_instance_->SyncStream();
|
||||
}
|
||||
|
||||
bool AscendDeviceResManager::CreateStream(void **stream) const {
|
||||
MS_EXCEPTION_IF_NULL(stream);
|
||||
auto ret = rtStreamCreate(stream, 0);
|
||||
if (ret != RT_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "Failed to create ascend stream, ret[" << ret << "]";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool AscendDeviceResManager::DestroyStream(void *stream) const {
|
||||
MS_EXCEPTION_IF_NULL(stream);
|
||||
auto ret = rtStreamDestroy(stream);
|
||||
if (ret != RT_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "Failed to destroy ascend stream, ret[" << ret << "]";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace ascend
|
||||
} // namespace device
|
||||
} // namespace mindspore
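AllocateContinuousMemory above aligns every requested size before handing the whole list to the memory pool, so communication operators can get contiguous buffers. A minimal, self-contained sketch of one way that align-then-carve idea can be realized, assuming a single backing block, a made-up 512-byte alignment and a malloc-backed stand-in pool (the real pool implementation is not shown in this commit):

// Hedged sketch: align each size, allocate one contiguous block, carve sub-buffers.
#include <cstddef>
#include <cstdlib>
#include <numeric>
#include <vector>

constexpr std::size_t kDemoAlign = 512;

std::size_t AlignUp(std::size_t size) { return (size + kDemoAlign - 1) / kDemoAlign * kDemoAlign; }

std::vector<void *> AllocateContinuousDemo(const std::vector<std::size_t> &size_list, void **block_out) {
  std::vector<std::size_t> aligned;
  aligned.reserve(size_list.size());
  for (std::size_t s : size_list) aligned.push_back(AlignUp(s));
  std::size_t total = std::accumulate(aligned.begin(), aligned.end(), std::size_t{0});
  char *block = static_cast<char *>(std::malloc(total));  // one contiguous block
  if (block == nullptr) {
    return {};
  }
  std::vector<void *> addrs;
  std::size_t offset = 0;
  for (std::size_t s : aligned) {  // carve aligned sub-buffers out of the block
    addrs.push_back(block + offset);
    offset += s;
  }
  *block_out = block;  // caller frees the whole block once
  return addrs;
}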
|
|
@ -0,0 +1,86 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_RUNTIME_HARDWARE_ASCEND_ASCEND_RES_MANAGER_H_
|
||||
#define MINDSPORE_CCSRC_RUNTIME_HARDWARE_ASCEND_ASCEND_RES_MANAGER_H_
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <set>
|
||||
#include <map>
|
||||
#include "runtime/hardware/device_context.h"
|
||||
#include "runtime/device/memory_manager.h"
|
||||
#include "plugin/device/ascend/hal/device/ascend_kernel_runtime.h"
|
||||
#include "plugin/device/ascend/hal/device/ascend_device_address.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace device {
|
||||
namespace ascend {
|
||||
class AscendDeviceResManager : public DeviceResManager {
|
||||
public:
|
||||
AscendDeviceResManager() : mem_manager_(nullptr) {}
|
||||
~AscendDeviceResManager() override = default;
|
||||
|
||||
void Initialize() override;
|
||||
|
||||
void Destroy() override;
|
||||
|
||||
// set rt_context_ to this thread to control device
|
||||
bool BindDeviceToCurrentThread() const override;
|
||||
|
||||
// Relevant function to allocate and free device memory of raw ptr.
|
||||
void *AllocateMemory(size_t size) const override;
|
||||
void FreeMemory(void *ptr) const override;
|
||||
|
||||
// Allocate continuous device memory according to size list.
|
||||
// Communication operators may need continuous memory for input and output
|
||||
// to optimize the communication performance.
|
||||
std::vector<void *> AllocateContinuousMemory(const std::vector<size_t> &size_list) const override;
|
||||
|
||||
// Create concrete device address according different device type.
|
||||
DeviceAddressPtr CreateDeviceAddress(void *const device_ptr, size_t device_size, const string &format, TypeId type_id,
|
||||
const ShapeVector &shape = ShapeVector()) const override;
|
||||
|
||||
// Synchronize stream, device such as GPU and Ascend need stream to launch kernel asynchronously,
|
||||
// using 'SyncStream' to block thread and wait for completing all tasks in stream.
|
||||
// Devices that do not need stream could ignore the implementation of this function.
|
||||
bool SyncStream(size_t stream_id = 0) const override;
|
||||
|
||||
protected:
|
||||
// Really create an ascend stream.
|
||||
bool CreateStream(void **stream) const override;
|
||||
|
||||
// Really destroy an ascend stream.
|
||||
bool DestroyStream(void *stream) const override;
|
||||
|
||||
private:
|
||||
friend class AscendKernelExecutor;
|
||||
friend class AscendGraphExecutor;
|
||||
friend class AscendDeviceContext;
|
||||
|
||||
// rank id of physical device
|
||||
uint32_t rank_id_{0};
|
||||
void *compute_stream_;
|
||||
void *communication_stream_;
|
||||
// Kernel Runtime --- only for task sink
|
||||
AscendKernelRuntime *runtime_instance_{nullptr};
|
||||
std::shared_ptr<MemoryManager> mem_manager_{nullptr};
|
||||
};
|
||||
} // namespace ascend
|
||||
} // namespace device
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_RUNTIME_HARDWARE_ASCEND_ASCEND_RES_MANAGER_H_
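The resource manager header pairs CreateStream with DestroyStream. A common way to use such a pair safely is an RAII guard, sketched below so a stream cannot leak on early returns; DemoStreamManager is a stand-in interface, not part of this commit, and only the create/destroy pairing mirrors the header above.

// Hedged sketch: RAII ownership around a CreateStream/DestroyStream pair.
#include <stdexcept>

struct DemoStreamManager {
  bool CreateStream(void **stream) const { *stream = new int(0); return true; }   // stand-in
  bool DestroyStream(void *stream) const { delete static_cast<int *>(stream); return true; }
};

class ScopedStream {
 public:
  explicit ScopedStream(const DemoStreamManager &mgr) : mgr_(mgr) {
    if (!mgr_.CreateStream(&stream_)) {
      throw std::runtime_error("Failed to create stream");
    }
  }
  ~ScopedStream() { (void)mgr_.DestroyStream(stream_); }
  ScopedStream(const ScopedStream &) = delete;
  ScopedStream &operator=(const ScopedStream &) = delete;
  void *get() const { return stream_; }

 private:
  const DemoStreamManager &mgr_;
  void *stream_{nullptr};
};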
|
|
@ -0,0 +1,354 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "plugin/device/ascend/hal/hardware/ascend_graph_executor.h"
|
||||
#include <unordered_map>
|
||||
#include <algorithm>
|
||||
#include "plugin/device/ascend/hal/hardware/ascend_utils.h"
|
||||
#include "backend/common/session/kernel_graph.h"
|
||||
#include "plugin/device/ascend/hal/device/kernel_build_ascend.h"
|
||||
#include "runtime/device/kernel_adjust.h"
|
||||
#include "plugin/device/ascend/hal/device/ascend_stream_assign.h"
|
||||
#include "plugin/device/ascend/hal/device/ascend_memory_adapter.h"
|
||||
#ifndef ENABLE_SECURITY
|
||||
#include "profiler/device/ascend/memory_profiling.h"
|
||||
using mindspore::profiler::ascend::MemoryProfiling;
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
namespace device {
|
||||
namespace ascend {
|
||||
using KernelGraph = mindspore::session::KernelGraph;
|
||||
|
||||
CNodePtr GetNextLabelSet(const std::vector<CNodePtr> &kernel_nodes, uint32_t index) {
|
||||
size_t node_sizes = kernel_nodes.size();
|
||||
if (index >= node_sizes - 1) {
|
||||
MS_LOG(EXCEPTION) << "there is no node after this node:" << kernel_nodes[index]->DebugString();
|
||||
}
|
||||
auto kernel = kernel_nodes[index + 1];
|
||||
if (common::AnfAlgo::GetCNodeName(kernel) != kLabelSetOpName) {
|
||||
MS_LOG(EXCEPTION) << "the node is not labelset follow labelgoto/labelswitch, node: "
|
||||
<< kernel_nodes[index]->DebugString();
|
||||
}
|
||||
return kernel;
|
||||
}
|
||||
|
||||
std::vector<CNodePtr> HandleRecursiveCall(const std::vector<CNodePtr> &kernel_cnodes, const uint32_t &back_label,
|
||||
uint32_t *index, std::vector<CNodePtr> *back) {
|
||||
MS_EXCEPTION_IF_NULL(index);
|
||||
MS_EXCEPTION_IF_NULL(back);
|
||||
std::vector<CNodePtr> front;
|
||||
std::vector<CNodePtr> back_temp;
|
||||
bool back_flag = false;
|
||||
uint32_t i = *index;
|
||||
while (i < kernel_cnodes.size()) {
|
||||
if (!back_flag) {
|
||||
front.emplace_back(kernel_cnodes[i]);
|
||||
} else {
|
||||
back->emplace_back(kernel_cnodes[i]);
|
||||
}
|
||||
if (common::AnfAlgo::HasNodeAttr(kAttrRecursiveEnd, kernel_cnodes[i])) {
|
||||
*index = i;
|
||||
back->insert(back->end(), back_temp.begin(), back_temp.end());
|
||||
return front;
|
||||
}
|
||||
if (common::AnfAlgo::HasNodeAttr(kAttrRecursive, kernel_cnodes[i])) {
|
||||
back_flag = true;
|
||||
if (!common::AnfAlgo::IsLabelIndexInNode(kernel_cnodes[i], back_label)) {
|
||||
auto temp = HandleRecursiveCall(kernel_cnodes, back_label, &(++i), &back_temp);
|
||||
front.insert(front.end(), temp.begin(), temp.end());
|
||||
}
|
||||
}
|
||||
i++;
|
||||
}
|
||||
return front;
|
||||
}
|
||||
|
||||
void UnfoldRecursiveExecOrder(KernelGraph *kernel_graph) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
if (!kernel_graph->recursive_call()) {
|
||||
return;
|
||||
}
|
||||
auto kernel_cnodes = kernel_graph->mem_reuse_exec_order();
|
||||
std::vector<CNodePtr> mem_reuse_order;
|
||||
mem_reuse_order.reserve(kernel_cnodes.size());
|
||||
for (uint32_t i = 0; i < kernel_cnodes.size(); i++) {
|
||||
if (!common::AnfAlgo::HasNodeAttr(kAttrRecursiveStart, kernel_cnodes[i])) {
|
||||
mem_reuse_order.emplace_back(kernel_cnodes[i]);
|
||||
continue;
|
||||
}
|
||||
auto label_id = common::AnfAlgo::GetNodeAttr<uint32_t>(kernel_cnodes[i], kAttrLabelIndex);
|
||||
std::vector<CNodePtr> back;
|
||||
auto front = HandleRecursiveCall(kernel_cnodes, label_id, &i, &back);
|
||||
mem_reuse_order.insert(mem_reuse_order.end(), front.begin(), front.end());
|
||||
mem_reuse_order.insert(mem_reuse_order.end(), back.begin(), back.end());
|
||||
}
|
||||
kernel_graph->set_mem_reuse_exec_order(mem_reuse_order);
|
||||
}
|
||||
|
||||
void GetSubGraphExecOrder(const KernelGraph *kernel_graph, uint32_t index, const CNodePtr &back_node,
|
||||
std::vector<CNodePtr> *mem_reuse_order) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
MS_EXCEPTION_IF_NULL(mem_reuse_order);
|
||||
auto label_id = common::AnfAlgo::GetNodeAttr<uint32_t>(back_node, kAttrLabelIndex);
|
||||
auto kernel_cnodes = kernel_graph->execution_order();
|
||||
for (auto i = index; i < kernel_cnodes.size(); i++) {
|
||||
mem_reuse_order->emplace_back(kernel_cnodes[i]);
|
||||
if (common::AnfAlgo::IsLabelIndexInNode(kernel_cnodes[i], label_id)) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void InitMemReuseExecOrder(KernelGraph *kernel_graph) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
if (!kernel_graph->subgraph_multi_call()) {
|
||||
return;
|
||||
}
|
||||
std::unordered_map<uint32_t, uint32_t> label_id_index_map;
|
||||
auto kernel_cnodes = kernel_graph->execution_order();
|
||||
std::vector<CNodePtr> mem_reuse_order;
|
||||
for (uint32_t i = 0; i < kernel_cnodes.size(); i++) {
|
||||
mem_reuse_order.emplace_back(kernel_cnodes[i]);
|
||||
if (common::AnfAlgo::CheckPrimitiveType(kernel_cnodes[i], prim::kPrimLabelSwitch) &&
|
||||
!common::AnfAlgo::HasNodeAttr(kAttrRecursive, kernel_cnodes[i]) &&
|
||||
!common::AnfAlgo::HasNodeAttr(kAttrReturn, kernel_cnodes[i])) {
|
||||
auto label_list = common::AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(kernel_cnodes[i], kAttrLabelSwitchList);
|
||||
for (auto label_id : label_list) {
|
||||
if (label_id_index_map.find(label_id) == label_id_index_map.end()) {
|
||||
continue;
|
||||
}
|
||||
auto back_node = GetNextLabelSet(kernel_cnodes, i);
|
||||
GetSubGraphExecOrder(kernel_graph, label_id_index_map[label_id], back_node, &mem_reuse_order);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (common::AnfAlgo::CheckPrimitiveType(kernel_cnodes[i], prim::kPrimLabelGoto) &&
|
||||
!common::AnfAlgo::HasNodeAttr(kAttrRecursive, kernel_cnodes[i]) &&
|
||||
!common::AnfAlgo::HasNodeAttr(kAttrReturn, kernel_cnodes[i])) {
|
||||
auto label_id = common::AnfAlgo::GetNodeAttr<uint32_t>(kernel_cnodes[i], kAttrLabelIndex);
|
||||
if (label_id_index_map.find(label_id) == label_id_index_map.end()) {
|
||||
continue;
|
||||
}
|
||||
auto back_node = GetNextLabelSet(kernel_cnodes, i);
|
||||
GetSubGraphExecOrder(kernel_graph, label_id_index_map[label_id], back_node, &mem_reuse_order);
|
||||
continue;
|
||||
}
|
||||
if (common::AnfAlgo::CheckPrimitiveType(kernel_cnodes[i], prim::kPrimLabelSet) &&
|
||||
!common::AnfAlgo::HasNodeAttr(kAttrRecursive, kernel_cnodes[i])) {
|
||||
auto label_id = common::AnfAlgo::GetNodeAttr<uint32_t>(kernel_cnodes[i], kAttrLabelIndex);
|
||||
if (label_id_index_map.find(label_id) != label_id_index_map.end()) {
|
||||
MS_LOG(EXCEPTION) << "Two labelsets with same label id.";
|
||||
}
|
||||
label_id_index_map[label_id] = i;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
kernel_graph->set_mem_reuse_exec_order(mem_reuse_order);
|
||||
UnfoldRecursiveExecOrder(kernel_graph);
|
||||
}
|
||||
|
||||
void AscendGraphExecutor::Initialize() {
|
||||
res_manager_ = dynamic_cast<AscendDeviceResManager *>(device_context_->device_res_manager_.get());
|
||||
MS_EXCEPTION_IF_NULL(res_manager_);
|
||||
runtime_instance_ = res_manager_->runtime_instance_;
|
||||
MS_EXCEPTION_IF_NULL(runtime_instance_);
|
||||
mem_manager_ = res_manager_->mem_manager_;
|
||||
MS_EXCEPTION_IF_NULL(mem_manager_);
|
||||
}
|
||||
|
||||
void AscendGraphExecutor::Destroy() {
|
||||
graph_event_.clear();
|
||||
mem_manager_ = nullptr;
|
||||
runtime_instance_ = nullptr;
|
||||
res_manager_ = nullptr;
|
||||
}
|
||||
|
||||
bool AscendGraphExecutor::RunGraph(const FuncGraphPtr &graph, const std::vector<tensor::Tensor> &inputs,
|
||||
std::vector<tensor::Tensor> *outputs,
|
||||
const std::map<string, string> &compile_options) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
auto kernel_graph = graph->cast<KernelGraphPtr>();
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
MS_LOG(INFO) << "Status record: start launch graph. graph id: " << kernel_graph->graph_id();
|
||||
PROF_START(launch_graph);
|
||||
MS_EXCEPTION_IF_NULL(runtime_instance_);
|
||||
runtime_instance_->SetContext();
|
||||
SetErrorManagerContext();
|
||||
device::KernelAdjust::GetInstance().LoadDeviceLoopCtrlParameters(kernel_graph);
|
||||
auto ret = ExecuteGraph(kernel_graph);
|
||||
if (!ret) {
|
||||
MS_LOG(ERROR) << "run task error!";
|
||||
ReportErrorMessage();
|
||||
return ret;
|
||||
}
|
||||
ReportWarningMessage();
|
||||
PROF_END(launch_graph);
|
||||
MS_LOG(INFO) << "Status record: end launch graph. graph id: " << kernel_graph->graph_id();
|
||||
return ret;
|
||||
}
|
||||
|
||||
void AscendGraphExecutor::PreprocessBeforeRun(const KernelGraphPtr &graph) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
device::ascend::InsertAtomicCleanOps(graph->execution_order(), &node_atomics_);
|
||||
UpdateExecOrder(graph);
|
||||
device::KernelAdjust::GetInstance().InsertDeviceLoopCtrl(graph);
|
||||
device::KernelAdjust::GetInstance().ProcessLoopSink(graph);
|
||||
AscendStreamAssign::GetInstance().AssignStream(NOT_NULL(graph));
|
||||
#ifndef ENABLE_SECURITY
|
||||
// Insert profiling point, this function must be executed after assign stream.
|
||||
device::KernelAdjust::GetInstance().Profiling(NOT_NULL(graph.get()));
|
||||
#endif
|
||||
device_context_->kernel_executor_->CreateKernel(graph->execution_order());
|
||||
AllocateGraphMemory(NOT_NULL(graph));
|
||||
LoadModel(NOT_NULL(graph));
|
||||
AssignOutputNopNodeDeviceAddress(graph, device_context_);
|
||||
}
|
||||
|
||||
void AscendGraphExecutor::UpdateExecOrder(const KernelGraphPtr &graph) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
std::vector<CNodePtr> new_orders;
|
||||
auto nodes = graph->execution_order();
|
||||
for (const auto &node : nodes) {
|
||||
if (node_atomics_.find(node) != node_atomics_.end()) {
|
||||
auto atomics = node_atomics_[node];
|
||||
(void)std::copy(atomics.begin(), atomics.end(), std::back_inserter(new_orders));
|
||||
}
|
||||
new_orders.push_back(node);
|
||||
}
|
||||
graph->set_execution_order(new_orders);
|
||||
node_atomics_.clear();
|
||||
}
|
||||
|
||||
void AscendGraphExecutor::AllocateGraphMemory(const NotNull<KernelGraphPtr> &root_graph) const {
|
||||
MS_LOG(INFO) << "Status record: start memory alloc. graph id: " << root_graph->graph_id();
|
||||
PROF_START(graph_memory_alloc);
|
||||
MS_EXCEPTION_IF_NULL(runtime_instance_);
|
||||
runtime_instance_->ClearGlobalIdleMem();
|
||||
std::set<KernelGraphPtr> memo;
|
||||
memo.clear();
|
||||
mem_manager_->ResetDynamicMemory();
|
||||
AssignInputMemory(root_graph, NOT_NULL(&memo));
|
||||
device::KernelAdjust::GetInstance().AssignLoopCtrlMemory(*root_graph.get());
|
||||
InitMemReuseExecOrder(root_graph.get().get());
|
||||
runtime_instance_->SetReuseCommunicationAddress(*root_graph.get());
|
||||
runtime_instance_->AssignStaticMemoryOutput(*root_graph.get());
|
||||
runtime_instance_->AssignDynamicMemory(*root_graph.get());
|
||||
runtime_instance_->UpdateRefNodeOutputMem(*root_graph.get());
|
||||
|
||||
PROF_END(graph_memory_alloc);
|
||||
MS_LOG(INFO) << "Status record: end memory alloc. graph id: " << root_graph->graph_id()
|
||||
<< ", Memory Statistics: " << device::ascend::AscendMemAdapter::GetInstance().DevMemStatistics();
|
||||
MS_LOG(INFO) << "The dynamic memory pool total size is: "
|
||||
<< device::ascend::AscendMemoryPool::GetInstance().TotalMemStatistics() / kMBToByte
|
||||
<< "M, total used size is "
|
||||
<< device::ascend::AscendMemoryPool::GetInstance().TotalUsedMemStatistics() / kMBToByte
|
||||
<< "M, used peak size is "
|
||||
<< device::ascend::AscendMemoryPool::GetInstance().UsedMemPeakStatistics() / kMBToByte << "M.";
|
||||
|
||||
#ifndef ENABLE_SECURITY
|
||||
if (MemoryProfiling::GetInstance().IsMemoryProfilingInitialized()) {
|
||||
uint64_t mem_size = runtime_instance_->GetMsUsedHbmSize();
|
||||
MemoryProfiling::GetInstance().SetDeviceMemSize(mem_size);
|
||||
if (MemoryProfiling::GetInstance().NeedSaveMemoryProfiling()) {
|
||||
MemoryProfiling::GetInstance().SaveMemoryProfiling();
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
void AscendGraphExecutor::AssignInputMemory(const NotNull<KernelGraphPtr> &graph,
|
||||
NotNull<std::set<KernelGraphPtr> *> const memo) const {
|
||||
if (memo->find(graph) != memo->end()) {
|
||||
return;
|
||||
}
|
||||
memo->insert(graph.get());
|
||||
|
||||
MS_LOG(INFO) << "Start to assign static memory for Parameter and Value node in graph: " << graph->graph_id();
|
||||
runtime_instance_->AssignStaticMemoryInput(*graph.get());
|
||||
runtime_instance_->AssignStaticMemoryValueNode(*graph.get());
|
||||
for (auto &child_graph : graph->child_graph_order()) {
|
||||
AssignInputMemory(NOT_NULL(child_graph.lock()), memo);
|
||||
}
|
||||
MS_LOG(INFO) << "Finish assigning static memory for Parameter and Value node in graph: " << graph->graph_id();
|
||||
}
|
||||
|
||||
void AscendGraphExecutor::LoadModel(const NotNull<KernelGraphPtr> &root_graph) const {
|
||||
MS_LOG(INFO) << "Status record: start load model. graph id: " << root_graph->graph_id();
|
||||
PROF_START(load_model);
|
||||
MS_EXCEPTION_IF_NULL(runtime_instance_);
|
||||
bool ret_ok = runtime_instance_->Load(*root_graph.get(), true);
|
||||
if (!ret_ok) {
|
||||
MS_LOG(EXCEPTION) << "Load task error!";
|
||||
}
|
||||
PROF_END(load_model);
|
||||
MS_LOG(INFO) << "Status record: end load model. graph id: " << root_graph->graph_id();
|
||||
}
|
||||
|
||||
bool AscendGraphExecutor::ExecuteGraph(const KernelGraphPtr &graph) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
const uint64_t kUSecondInSecond = 1000000;
|
||||
bool ret = false;
|
||||
if (graph->is_graph_run_mode()) {
|
||||
InsertEventBeforeRunTask(graph);
|
||||
|
||||
#if defined(_WIN32) || defined(_WIN64)
|
||||
auto start_time = std::chrono::steady_clock::now();
|
||||
#else
|
||||
struct timeval start_time {};
|
||||
struct timeval end_time {};
|
||||
(void)gettimeofday(&start_time, nullptr);
|
||||
#endif
|
||||
MS_EXCEPTION_IF_NULL(runtime_instance_);
|
||||
{
|
||||
std::lock_guard<std::mutex> locker(launch_mutex_);
|
||||
ret = runtime_instance_->RunTask(*graph);
|
||||
}
|
||||
#if defined(_WIN32) || defined(_WIN64)
|
||||
auto end_time = std::chrono::steady_clock::now();
|
||||
std::chrono::duration<double, std::ratio<1, kUSecondInSecond>> cost = end_time - start_time;
|
||||
MS_LOG(INFO) << "Call MS Run Success in " << cost.count() << " us";
|
||||
#else
|
||||
(void)gettimeofday(&end_time, nullptr);
|
||||
uint64_t cost = kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec);
|
||||
cost += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec);
|
||||
MS_LOG(INFO) << "Call MS Run Success in " << cost << " us";
|
||||
#endif
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << graph->ToString() << " does not sink, should launch kernels";
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
void AscendGraphExecutor::InsertEventBeforeRunTask(const KernelGraphPtr &graph) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
if (!graph->is_graph_run_mode() || graph->is_dynamic_shape()) {
|
||||
return;
|
||||
}
|
||||
MS_LOG(DEBUG) << "Insert event between PyNative and Graph";
|
||||
MS_EXCEPTION_IF_NULL(runtime_instance_);
|
||||
auto model_stream = runtime_instance_->GetModelStream(graph->graph_id());
|
||||
auto compute_event = runtime_instance_->CreateDeviceEvent();
|
||||
MS_EXCEPTION_IF_NULL(compute_event);
|
||||
compute_event->set_wait_stream(model_stream);
|
||||
compute_event->set_record_stream(res_manager_->compute_stream_);
|
||||
compute_event->RecordEvent();
|
||||
compute_event->WaitEvent();
|
||||
graph_event_[graph->graph_id()] = compute_event;
|
||||
}
|
||||
} // namespace ascend
|
||||
} // namespace device
|
||||
} // namespace mindspore
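ExecuteGraph above times the sunk-graph launch with gettimeofday on POSIX and std::chrono::steady_clock on Windows. A small self-contained helper showing the steady_clock variant, which behaves the same on both platforms; the callable is a placeholder for the real RunTask, so this is a sketch rather than the shipped implementation.

// Hedged sketch: portable launch timing with std::chrono::steady_clock.
#include <chrono>
#include <functional>
#include <iostream>

bool TimedLaunch(const std::function<bool()> &run_task_demo) {
  auto start = std::chrono::steady_clock::now();
  bool ok = run_task_demo();
  auto end = std::chrono::steady_clock::now();
  auto cost_us = std::chrono::duration_cast<std::chrono::microseconds>(end - start).count();
  std::cout << "Call MS Run Success in " << cost_us << " us\n";
  return ok;
}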
|
|
@ -0,0 +1,76 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_RUNTIME_HARDWARE_ASCEND_ASCEND_GRAPH_EXECUTOR_H_
|
||||
#define MINDSPORE_CCSRC_RUNTIME_HARDWARE_ASCEND_ASCEND_GRAPH_EXECUTOR_H_
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <set>
|
||||
#include <map>
|
||||
#include "runtime/hardware/device_context.h"
|
||||
#include "runtime/device/memory_manager.h"
|
||||
#include "plugin/device/ascend/hal/device/ascend_kernel_runtime.h"
|
||||
#include "plugin/device/ascend/hal/hardware/ascend_device_res_manager.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace device {
|
||||
namespace ascend {
|
||||
class AscendGraphExecutor : public GraphExecutor {
|
||||
public:
|
||||
AscendGraphExecutor() = default;
|
||||
~AscendGraphExecutor() override = default;
|
||||
|
||||
void Initialize();
|
||||
void Destroy();
|
||||
|
||||
// Launch graph, device such as Ascend support the whole graph sink to the device executing.
|
||||
bool RunGraph(const FuncGraphPtr &graph, const std::vector<tensor::Tensor> &inputs,
|
||||
std::vector<tensor::Tensor> *outputs, const std::map<string, string> &compile_options) override;
|
||||
|
||||
// Adjust kernel graph before run graph, used in Graph Mode.
|
||||
void PreprocessBeforeRun(const KernelGraphPtr &graph) const;
|
||||
|
||||
private:
|
||||
// compile graph interface
|
||||
void UpdateExecOrder(const KernelGraphPtr &graph) const;
|
||||
void AllocateGraphMemory(const NotNull<KernelGraphPtr> &root_graph) const;
|
||||
void AssignInputMemory(const NotNull<KernelGraphPtr> &graph, NotNull<std::set<KernelGraphPtr> *> memo) const;
|
||||
void LoadModel(const NotNull<KernelGraphPtr> &root_graph) const;
|
||||
|
||||
// LaunchGraph interface
|
||||
void InsertEventBeforeRunTask(const KernelGraphPtr &graph) const;
|
||||
bool ExecuteGraph(const KernelGraphPtr &graph) const;
|
||||
|
||||
// Kernel Runtime --- only for task sink
|
||||
AscendKernelRuntime *runtime_instance_{nullptr};
|
||||
std::shared_ptr<MemoryManager> mem_manager_{nullptr};
|
||||
|
||||
// ExecuteGraph is not thread-safe, so it is not recommended that multiple threads call it at the same
|
||||
// time; the launch mutex is needed when multiple threads launch the graph.
|
||||
mutable std::mutex launch_mutex_;
|
||||
// Using the node to get its atomics
|
||||
mutable std::map<CNodePtr, std::vector<CNodePtr>> node_atomics_;
|
||||
// Event for multi-stream
|
||||
mutable std::map<uint32_t, std::shared_ptr<DeviceEvent>> graph_event_;
|
||||
AscendDeviceResManager *res_manager_{nullptr};
|
||||
};
|
||||
} // namespace ascend
|
||||
} // namespace device
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_RUNTIME_HARDWARE_ASCEND_ASCEND_GRAPH_EXECUTOR_H_
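AssignInputMemory declared above walks the root graph and its child graphs recursively with a memo set, so a child graph shared by several parents is assigned only once. A generic, self-contained sketch of that memoized traversal with a stand-in graph type; only the visit-once idea mirrors the real code.

// Hedged sketch: memoized recursive traversal over a graph hierarchy.
#include <iostream>
#include <memory>
#include <set>
#include <vector>

struct DemoGraph {
  int id;
  std::vector<std::shared_ptr<DemoGraph>> children;
};

void AssignInputMemoryDemo(const std::shared_ptr<DemoGraph> &graph, std::set<const DemoGraph *> *memo) {
  if (memo->count(graph.get()) > 0) {
    return;  // this (shared) child graph was already assigned
  }
  memo->insert(graph.get());
  std::cout << "assign static input memory for graph " << graph->id << "\n";
  for (const auto &child : graph->children) {
    AssignInputMemoryDemo(child, memo);
  }
}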
|
|
@ -0,0 +1,477 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "plugin/device/ascend/hal/hardware/ascend_kernel_executor.h"
|
||||
#include <algorithm>
|
||||
#include "plugin/device/ascend/hal/hardware/ascend_utils.h"
|
||||
#include "plugin/device/ascend/hal/hardware/ascend_graph_optimization.h"
|
||||
#include "plugin/device/ascend/hal/device/kernel_select_ascend.h"
|
||||
#include "plugin/device/ascend/hal/device/kernel_build_ascend.h"
|
||||
#include "plugin/device/ascend/kernel/aicpu/aicpu_kernel_load.h"
|
||||
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_compile.h"
|
||||
#include "plugin/device/ascend/hal/device/ascend_bucket.h"
|
||||
#include "plugin/device/ascend/hal/device/ascend_stream_assign.h"
|
||||
#include "include/common/utils/parallel_context.h"
|
||||
#include "kernel/ascend_kernel_mod.h"
|
||||
#include "acl/acl_rt.h"
|
||||
|
||||
#ifndef ENABLE_SECURITY
|
||||
#include "debug/data_dump/dump_json_parser.h"
|
||||
#include "toolchain/adx_datadump_server.h"
|
||||
#include "toolchain/adx_datadump_callback.h"
|
||||
#include "include/common/debug/anf_ir_dump.h"
|
||||
#include "include/common/debug/dump_proto.h"
|
||||
#include "debug/data_dump/e2e_dump.h"
|
||||
#include "debug/debugger/debugger_utils.h"
|
||||
#include "profiler/device/ascend/memory_profiling.h"
|
||||
#include "plugin/device/ascend/hal/device/profiling/profiling_manager.h"
|
||||
#include "utils/anf_utils.h"
|
||||
#include "profiler/device/ascend/pynative_profiling.h"
|
||||
#include "profiler/device/ascend/ascend_profiling.h"
|
||||
|
||||
using Adx::AdxRegDumpProcessCallBack;
|
||||
using mindspore::device::ascend::ProfilingManager;
|
||||
using mindspore::profiler::ascend::MemoryProfiling;
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
namespace device {
|
||||
namespace ascend {
|
||||
constexpr size_t kAtomicCleanInputSize = 2;
|
||||
namespace {
|
||||
/*
|
||||
* Feature group: Dump.
|
||||
* Target device group: Ascend.
|
||||
* Runtime category: MindRT.
|
||||
* Description: Parse config json file and register callback to adx.
|
||||
*/
|
||||
#ifndef ENABLE_SECURITY
|
||||
void DumpInit(uint32_t device_id) {
|
||||
auto &json_parser = DumpJsonParser::GetInstance();
|
||||
json_parser.Parse();
|
||||
json_parser.CopyDumpJsonToDir(device_id);
|
||||
json_parser.CopyHcclJsonToDir(device_id);
|
||||
json_parser.CopyMSCfgJsonToDir(device_id);
|
||||
if (json_parser.async_dump_enabled()) {
|
||||
#if !(defined(ENABLE_TEST) || defined(ENABLE_TESTCASES))
|
||||
// register callback to adx
|
||||
if (json_parser.FileFormatIsNpy()) {
|
||||
AdxRegDumpProcessCallBack(DumpDataCallBack);
|
||||
}
|
||||
#endif
|
||||
if (AdxDataDumpServerInit() != 0) {
|
||||
MS_LOG(EXCEPTION) << "Adx data dump server init failed";
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
} // namespace
|
||||
|
||||
void AscendKernelExecutor::Initialize() {
|
||||
kernel::ascend::TbeKernelCompileManager::GetInstance().TbeInitialize();
|
||||
res_manager_ = dynamic_cast<AscendDeviceResManager *>(device_context_->device_res_manager_.get());
|
||||
MS_EXCEPTION_IF_NULL(res_manager_);
|
||||
graph_executor_ = dynamic_cast<AscendGraphExecutor *>(device_context_->graph_executor_.get());
|
||||
MS_EXCEPTION_IF_NULL(graph_executor_);
|
||||
#ifndef ENABLE_SECURITY
|
||||
DumpInit(res_manager_->rank_id_);
|
||||
#endif
|
||||
}
|
||||
|
||||
void AscendKernelExecutor::Destroy() {
|
||||
AscendGraphOptimization::GetInstance().Reset();
|
||||
res_manager_ = nullptr;
|
||||
graph_executor_ = nullptr;
|
||||
}
|
||||
|
||||
void AscendKernelExecutor::UnifyMindIR(const KernelGraphPtr &graph) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
AscendGraphOptimization::GetInstance().UnifyMindIR(graph);
|
||||
}
|
||||
|
||||
void AscendKernelExecutor::OptimizeGraph(const FuncGraphPtr &graph) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
auto kernel_graph = graph->cast<KernelGraphPtr>();
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
if (kernel_graph->is_from_single_op()) {
|
||||
AscendGraphOptimization::GetInstance().OptimizeSingleOpGraph(kernel_graph);
|
||||
} else {
|
||||
AscendGraphOptimization::GetInstance().OptimizeGraph(kernel_graph);
|
||||
}
|
||||
}
|
||||
|
||||
// Before creating the kernel, check whether the node has completed the operator selection. If not, the operator
|
||||
// selection needs to be performed to set kernel info.
|
||||
void SetKernelInfoBeforeCreateKernel(const std::vector<CNodePtr> &nodes) {
|
||||
// Check whether the node has completed kernel selection.
|
||||
for (const auto &node : nodes) {
|
||||
if (AnfAlgo::GetSelectKernelBuildInfo(node) != nullptr) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Kernel selection process.
|
||||
auto [status, msg, etype] = SelectKernelInfoWithMsg(node);
|
||||
if (status == device::ascend::kNoMatched) {
|
||||
MS_EXCEPTION(etype) << msg;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void AscendKernelExecutor::CreateKernel(const std::vector<CNodePtr> &nodes) const {
|
||||
SetKernelInfoBeforeCreateKernel(nodes);
|
||||
|
||||
MS_LOG(INFO) << "Status record: start create kernel.";
|
||||
PROF_START(create_kernel);
|
||||
auto ret = device::ascend::KernelBuild(nodes);
|
||||
if (!ret) {
|
||||
MS_LOG(EXCEPTION) << "Kernel build error.";
|
||||
}
|
||||
PROF_END(create_kernel);
|
||||
MS_LOG(INFO) << "Status record: end create kernel.";
|
||||
}
|
||||
|
||||
void AscendKernelExecutor::LaunchDeviceLibrary() const {
|
||||
MS_LOG(INFO) << "Status record: start launch device library.";
|
||||
auto ret = mindspore::kernel::AicpuOpKernelLoad::GetInstance().LaunchAicpuKernelSo();
|
||||
if (!ret) {
|
||||
MS_LOG(EXCEPTION) << "Cust aicpu kernel so load failed.";
|
||||
}
|
||||
MS_LOG(INFO) << "Status record: end launch device library.";
|
||||
}
|
||||
|
||||
void AscendKernelExecutor::SetAtomicCleanToNodes(const KernelGraphPtr &graph,
|
||||
const std::map<CNodePtr, std::vector<CNodePtr>> &atomics_node) const {
|
||||
// Don't clear node_atomics_ at the end, since atomic_clean_nodes_ in kernel.h uses weak pointers
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
auto nodes = graph->execution_order();
|
||||
for (const auto &node : nodes) {
|
||||
auto it = atomics_node.find(node);
|
||||
if (it != atomics_node.end()) {
|
||||
const auto &atomics = it->second;
|
||||
auto kernel_mod = AnfAlgo::GetKernelMod(node);
|
||||
auto ascend_kernel_mod = dynamic_cast<kernel::AscendKernelMod *>(kernel_mod);
|
||||
if (ascend_kernel_mod != nullptr) {
|
||||
ascend_kernel_mod->SetAtomicCleanNodes(atomics);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void AscendKernelExecutor::PreprocessBeforeRun(const FuncGraphPtr &graph) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
auto kernel_graph = graph->cast<KernelGraphPtr>();
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
if (kernel_graph->is_from_single_op()) {
|
||||
PreprocessBeforeRunSingleOpGraph(kernel_graph);
|
||||
} else {
|
||||
PreprocessBeforeRunGraph(kernel_graph);
|
||||
}
|
||||
}
|
||||
|
||||
void AscendKernelExecutor::PreprocessBeforeRunGraph(const KernelGraphPtr &graph) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_LOG(INFO) << "Status record: start preprocess before run graph. graph id: " << graph->graph_id();
|
||||
PROF_START(preprocess_before_run_graph);
|
||||
auto ascend_instance = profiler::ascend::AscendProfiler::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ascend_instance);
|
||||
if (graph->is_dynamic_shape()) {
|
||||
ascend_instance->SetNetDynamicShapeStatus();
|
||||
}
|
||||
SetErrorManagerContext();
|
||||
try {
|
||||
if (graph->is_graph_run_mode()) {
|
||||
graph_executor_->PreprocessBeforeRun(graph);
|
||||
} else if (graph->is_dynamic_shape() && (IsGraphMode() || graph->has_flag(kFlagPyNativeRunInGraph))) {
|
||||
device::ascend::InsertAtomicCleanOps(graph->execution_order(), &node_atomics_);
|
||||
SetAtomicCleanToNodes(graph, node_atomics_);  // Graph mode may be able to do this too, instead of updating the exec order
|
||||
opt::DynamicShapeConvertPass(graph);
|
||||
AscendStreamAssign::GetInstance().AssignStream(NOT_NULL(graph));
|
||||
AssignOutputNopNodeDeviceAddress(graph, device_context_);
|
||||
LaunchDeviceLibrary();
|
||||
} else {
|
||||
PreprocessBeforeRunSingleOpGraph(graph);
|
||||
AscendStreamAssign::GetInstance().AssignStream(NOT_NULL(graph));
|
||||
CreateKernel(graph->execution_order());
|
||||
MS_EXCEPTION_IF_NULL(res_manager_->runtime_instance_);
|
||||
res_manager_->runtime_instance_->SetKernelModRtStream(NOT_NULL(graph));
|
||||
}
|
||||
} catch (const std::exception &e) {
|
||||
ReportErrorMessage();
|
||||
MS_LOG(EXCEPTION) << "Preprocess failed before run graph " << graph->graph_id() << ", \nerror msg: " << e.what();
|
||||
}
|
||||
|
||||
const std::vector<CNodePtr> &kernels = graph->execution_order();
|
||||
for (const auto &kernel : kernels) {
|
||||
common::AnfAlgo::SetNodeAttr(kAttrMSFunction, MakeValue(true), kernel);
|
||||
}
|
||||
|
||||
PROF_END(preprocess_before_run_graph);
|
||||
MS_LOG(INFO) << "Status record: end preprocess before run graph. graph id: " << graph->graph_id();
|
||||
}
|
||||
|
||||
void AscendKernelExecutor::PreprocessBeforeRunSingleOpGraph(const KernelGraphPtr &graph) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
const auto &nodes = graph->execution_order();
|
||||
|
||||
for (const auto &node : nodes) {
|
||||
// Remove placeholder
|
||||
auto op_name = common::AnfAlgo::GetCNodeName(node);
|
||||
static const std::set<std::string> place_holder_nodes = {kDynamicRNNOpName, kDynamicGRUV2OpName};
|
||||
auto iter = place_holder_nodes.find(op_name);
|
||||
if (iter != place_holder_nodes.end()) {
|
||||
auto none_index = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(node, kAttrPlaceHolderIndex);
|
||||
// Remove seq_length
|
||||
auto input_num = common::AnfAlgo::GetInputTensorNum(node);
|
||||
std::vector<AnfNodePtr> new_inputs = {common::AnfAlgo::GetCNodePrimitiveNode(node)};
|
||||
for (size_t i = 0; i < input_num; ++i) {
|
||||
auto item = std::find(none_index.begin(), none_index.end(), i);
|
||||
if (item == none_index.end()) {
|
||||
auto input_node = common::AnfAlgo::GetInputNode(node, i);
|
||||
new_inputs.emplace_back(input_node);
|
||||
}
|
||||
}
|
||||
node->set_inputs(new_inputs);
|
||||
}
|
||||
|
||||
// Save the nop ops that need to be handled by memcpy
|
||||
static mindspore::HashSet<std::string> nop_nodes = {prim::kPrimReshape->name(), prim::kPrimExpandDims->name(),
|
||||
prim::kPrimSqueeze->name(), prim::kPrimFlatten->name(),
|
||||
prim::kPrimFlattenGrad->name()};
|
||||
// If the 2nd input of Reshape is not a value node, the node has two inputs and the host Reshape operator is selected
|
||||
bool is_host_reshape_op = false;
|
||||
if (op_name == prim::kPrimReshape->name()) {
|
||||
auto kernel_mod = AnfAlgo::GetKernelMod(node);
|
||||
MS_EXCEPTION_IF_NULL(kernel_mod);
|
||||
is_host_reshape_op = kernel_mod->GetKernelModType() == kernel::KernelModType::HostKernelMod;
|
||||
}
|
||||
bool nop_op_is_not_dynamic_shape = !graph->is_dynamic_shape() && nop_nodes.find(op_name) != nop_nodes.end();
|
||||
bool is_transpose_nop = op_name == prim::kPrimTranspose->name() && common::AnfAlgo::HasNodeAttr(kAttrNopOp, node);
|
||||
if (is_transpose_nop || (nop_op_is_not_dynamic_shape && !is_host_reshape_op)) {
|
||||
nop_op_to_memcpy_.insert(node);
|
||||
}
|
||||
}
|
||||
|
||||
device::ascend::InsertAtomicCleanOps(nodes, &node_atomics_persistent_cache_);
|
||||
std::vector<CNodePtr> atomic_nodes;
|
||||
for (const auto &node : nodes) {
|
||||
auto iter = node_atomics_persistent_cache_.find(node);
|
||||
if (iter != node_atomics_persistent_cache_.end()) {
|
||||
const auto &atomics = iter->second;
|
||||
std::copy(atomics.begin(), atomics.end(), std::back_inserter(atomic_nodes));
|
||||
}
|
||||
}
|
||||
|
||||
SetAtomicCleanToNodes(graph, node_atomics_persistent_cache_);
|
||||
CreateKernel(atomic_nodes);
|
||||
LaunchDeviceLibrary();
|
||||
}
|
||||
|
||||
std::shared_ptr<Bucket> AscendKernelExecutor::CreateBucket(uint32_t bucket_id, uint32_t bucket_size) const {
|
||||
auto bucket = std::make_shared<AscendBucket>(bucket_id, bucket_size);
|
||||
MS_EXCEPTION_IF_NULL(bucket);
|
||||
|
||||
// For data-parallel, there is no communication in forward and backward process, the only communication ops arise
|
||||
// from this allreduce bucket. All the ops in forward and backward process are assigned on the compute stream and
|
||||
// allreduce for gradients is assigned on communication stream.
|
||||
// But for semi/auto_parallel mode, there will be communication ops in forward and backward process. To avoid stream
|
||||
// sync error, for semi/auto_parallel mode, the allreduce for gradients is assigned on compute stream as well.
|
||||
auto parallel_context = parallel::ParallelContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(parallel_context);
|
||||
auto parallel_mode = parallel_context->parallel_mode();
|
||||
if (parallel_mode == parallel::kAutoParallel || parallel_mode == parallel::kSemiAutoParallel) {
|
||||
bucket->Init({res_manager_->compute_stream_}, {res_manager_->compute_stream_});
|
||||
} else {
|
||||
bucket->Init({res_manager_->compute_stream_}, {res_manager_->communication_stream_});
|
||||
}
|
||||
return bucket;
|
||||
}
|
||||
|
||||
bool AscendKernelExecutor::PySyncRuning() const {
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
if ((ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) &&
|
||||
ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE) && !res_manager_->SyncStream()) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool AscendKernelExecutor::MemoryCopyAsync(const CNodePtr &node, const vector<AddressPtr> &inputs,
|
||||
const vector<AddressPtr> &outputs) const {
|
||||
MS_LOG(DEBUG) << "Launch MemoryCopyAsync instead for kernel " << node->fullname_with_scope();
|
||||
if (inputs.size() != 1 || outputs.size() != 1) {
|
||||
MS_LOG(ERROR) << "Kernel " << node->fullname_with_scope() << " input output size should be 1 but"
|
||||
<< " input size is:" << inputs.size() << " output size is:" << outputs.size();
|
||||
return false;
|
||||
}
|
||||
|
||||
aclError status = aclrtMemcpyAsync(outputs[0]->addr, outputs[0]->size, inputs[0]->addr, inputs[0]->size,
|
||||
ACL_MEMCPY_DEVICE_TO_DEVICE, res_manager_->compute_stream_);
|
||||
if (status != ACL_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "MemCpyAsync op aclrtMemcpyAsync failed, ret:" << status;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void *AscendKernelExecutor::GetKernelStream(const CNodePtr &node) const {
|
||||
auto kernel_mod = AnfAlgo::GetKernelMod(node);
|
||||
MS_EXCEPTION_IF_NULL(kernel_mod);
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
|
||||
return res_manager_->compute_stream_;
|
||||
} else if (common::AnfAlgo::HasNodeAttr(kAttrStream, node)) {
auto stream_id = common::AnfAlgo::GetNodeAttr<size_t>(node, kAttrStream);
auto iter = res_manager_->stream_ids_.find(stream_id);
if (iter == res_manager_->stream_ids_.end()) {
MS_LOG(EXCEPTION) << "Can not find stream for stream id: " << stream_id;
}
void *stream = iter->second;
MS_EXCEPTION_IF_NULL(stream);
return stream;
} else {
auto stream = kernel_mod->stream();
if (stream == nullptr) {
stream = res_manager_->compute_stream_;
MS_LOG(INFO) << "Assign default compute stream for node " << node->fullname_with_scope();
}
return stream;
}
}

bool AscendKernelExecutor::GetKernelRealInputs(const CNodePtr &kernel, const vector<AddressPtr> &inputs,
std::vector<AddressPtr> *real_inputs) const {
auto input_num = common::AnfAlgo::GetInputTensorNum(kernel);
if (input_num != inputs.size()) {
MS_LOG(ERROR) << "Input num is " << input_num << " but input address num is " << inputs.size();
return false;
}

for (size_t i = 0; i < input_num; ++i) {
auto real_index = AnfAlgo::GetRealInputIndex(kernel, i);
if (real_index >= input_num) {
MS_LOG(ERROR) << "Total input num is " << input_num << " but get real_index " << real_index;
return false;
}
real_inputs->push_back(inputs[real_index]);
}
return true;
}

bool AscendKernelExecutor::LaunchKernel(const CNodePtr &kernel, const vector<AddressPtr> &inputs,
const vector<AddressPtr> &workspace, const vector<AddressPtr> &outputs) const {
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
auto graph_id = AnfAlgo::GetGraphId(kernel.get());
auto device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
KernelType kernel_type = AnfAlgo::GetKernelType(kernel);
MS_EXCEPTION_IF_NULL(kernel);
MS_LOG(DEBUG) << "Launch kernel: " << kernel->fullname_with_scope();
res_manager_->BindDeviceToCurrentThread();

std::vector<AddressPtr> real_inputs;
bool ret = GetKernelRealInputs(kernel, inputs, &real_inputs);
if (!ret) {
MS_LOG(ERROR) << "Get real input fail for kernel " << kernel->fullname_with_scope();
return false;
}
auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
MS_EXCEPTION_IF_NULL(kernel_mod);

bool is_dynamic_shape = common::AnfAlgo::IsDynamicShape(kernel);
if (!is_dynamic_shape || !(common::AnfAlgo::GetBooleanAttr(kernel, kAttrMSFunction))) {
std::lock_guard<std::mutex> locker(launch_mutex_);
// launch atomic clean
if (!LaunchAtomicClean(kernel, workspace, outputs)) {
MS_LOG(ERROR) << "Launch AtomicClean failed, pre kernel full name: " << kernel->fullname_with_scope();
return false;
}
}

// launch kernel
if (nop_op_to_memcpy_.find(kernel) != nop_op_to_memcpy_.end()) {
MemoryCopyAsync(kernel, real_inputs, outputs);
} else {
MS_LOG(DEBUG) << "Launch kernel " << kernel->fullname_with_scope();
auto stream = GetKernelStream(kernel);
#ifndef ENABLE_SECURITY
auto profiler_inst = profiler::ascend::PynativeProfiler::GetInstance();
MS_EXCEPTION_IF_NULL(profiler_inst);
std::thread::id t_id = std::this_thread::get_id();
(void)profiler_inst->OpDataProducerBegin(res_manager_->runtime_instance_, stream, t_id,
kernel->fullname_with_scope(), is_dynamic_shape);
#endif
ret = kernel_mod->Launch(real_inputs, workspace, outputs, stream);
#ifndef ENABLE_SECURITY
(void)profiler_inst->OpDataProducerEnd(t_id, is_dynamic_shape);
#endif
if (!ret) {
MS_LOG(ERROR) << "Launch kernel failed, kernel full name: " << kernel->fullname_with_scope();
return false;
}
}
auto ascend_instance = profiler::ascend::AscendProfiler::GetInstance();
MS_EXCEPTION_IF_NULL(ascend_instance);
if (ascend_instance->GetNetDynamicShapeStatus() && ascend_instance->GetProfilingEnableFlag()) {
ascend_instance->GetNodeTaskIdStreamId(kernel, graph_id, device_id, kernel_type);
}

return PySyncRuning();
}

bool AscendKernelExecutor::LaunchAtomicClean(const CNodePtr &node, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) const {
auto iter = node_atomics_persistent_cache_.find(node);
if (iter == node_atomics_persistent_cache_.end()) {
return true;
}
MS_LOG(DEBUG) << "Launch atomic clean for kernel " << node->fullname_with_scope();
auto atomic_node = iter->second.at(0);
vector<AddressPtr> atomic_inputs;
// The output addr need to clean
MS_EXCEPTION_IF_NULL(atomic_node);
if (atomic_node->inputs().size() != kAtomicCleanInputSize) {
MS_LOG(EXCEPTION) << "Atomic Addr clean Node Input nodes not equal 2.";
}
if (common::AnfAlgo::HasNodeAttr(kAttrAtomicOutputIndexs, node)) {
auto clean_output_indexes = common::AnfAlgo::GetNodeAttr<std::vector<size_t>>(node, kAttrAtomicOutputIndexs);
for (auto output_index : clean_output_indexes) {
if (output_index >= outputs.size()) {
MS_LOG(EXCEPTION) << "Invalid output_index:" << output_index << " except less than " << outputs.size();
}
atomic_inputs.push_back(outputs[output_index]);
}
}

// The workspace addr need to clean
if (common::AnfAlgo::HasNodeAttr(kAttrAtomicWorkspaceIndexs, node)) {
auto clean_workspace_indexes = common::AnfAlgo::GetNodeAttr<std::vector<size_t>>(node, kAttrAtomicWorkspaceIndexs);
for (auto workspace_index : clean_workspace_indexes) {
if (workspace_index >= workspace.size()) {
MS_LOG(EXCEPTION) << "Invalid workspace_index:" << workspace_index << " except less than " << workspace.size();
}
atomic_inputs.push_back(workspace[workspace_index]);
}
}
// Launch Atomic Node
auto kernel_mod = AnfAlgo::GetKernelMod(atomic_node);
MS_EXCEPTION_IF_NULL(kernel_mod);
return kernel_mod->Launch(atomic_inputs, {}, {}, GetKernelStream(node));
}
} // namespace ascend
} // namespace device
} // namespace mindspore
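For orientation, the stream selection in GetKernelStream() above reduces to a lookup-with-fallback policy: prefer a stream registered under the node's explicit stream-id attribute, otherwise the kernel's own stream, otherwise the resource manager's default compute stream. The following standalone C++ sketch (hypothetical names, not the MindSpore API) illustrates that policy in isolation:

#include <cstddef>
#include <cstdio>
#include <map>

// Simplified illustration of the fallback order used above:
// explicit stream id -> kernel's own stream -> default compute stream.
void *SelectStream(const std::map<size_t, void *> &stream_ids, bool has_stream_attr, size_t stream_id,
                   void *kernel_stream, void *default_compute_stream) {
  if (has_stream_attr) {
    auto iter = stream_ids.find(stream_id);
    if (iter == stream_ids.end()) {
      std::printf("Cannot find stream for stream id: %zu\n", stream_id);
      return nullptr;
    }
    return iter->second;
  }
  return kernel_stream != nullptr ? kernel_stream : default_compute_stream;
}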
@@ -0,0 +1,100 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_RUNTIME_HARDWARE_ASCEND_ASCEND_KERNEL_EXECUTOR_H_
#define MINDSPORE_CCSRC_RUNTIME_HARDWARE_ASCEND_ASCEND_KERNEL_EXECUTOR_H_

#include <vector>
#include <memory>
#include <string>
#include <set>
#include <map>
#include "runtime/hardware/device_context.h"
#include "runtime/device/memory_manager.h"
#include "plugin/device/ascend/hal/device/ascend_kernel_runtime.h"
#include "plugin/device/ascend/hal/device/ascend_device_address.h"
#include "plugin/device/ascend/hal/hardware/ascend_device_res_manager.h"
#include "plugin/device/ascend/hal/hardware/ascend_graph_executor.h"

namespace mindspore {
namespace device {
namespace ascend {
class AscendKernelExecutor : public DeprecatedKernelExecutor {
public:
AscendKernelExecutor() = default;
~AscendKernelExecutor() override = default;

void Initialize();
void Destroy();

// Optimize the kernel graph for graph mode.
void OptimizeGraph(const FuncGraphPtr &graph) const override;

// Generate 'KernelMod' for all kernels and set 'KernelMod' into kernel,
// 'KernelMod' is real executive object of kernel.
void CreateKernel(const std::vector<CNodePtr> &nodes) const override;

// Adjust kernel graph before run graph, used in Graph Mode.
void PreprocessBeforeRun(const FuncGraphPtr &graph) const override;

// Launch a kernel via 'KernelMod' of the kernel.
bool LaunchKernel(const CNodePtr &kernel, const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace, const std::vector<AddressPtr> &outputs) const override;

// Unify the MindIR, the default behavior uses the common unified MindIR.
void UnifyMindIR(const KernelGraphPtr &graph) const override;

// Get rank id for distributed training.
uint32_t GetRankID() const override { return res_manager_->rank_id_; }

// Create and initialize bucket for every allreduce operator. Bucket is used in PyNative distributed training mode,
// one bucket handles all resource to launch and sync allreduce operator.
std::shared_ptr<Bucket> CreateBucket(uint32_t bucket_id, uint32_t bucket_size) const override;

private:
// Launch device aicpu library
void LaunchDeviceLibrary() const;

void SetAtomicCleanToNodes(const KernelGraphPtr &graph,
const std::map<CNodePtr, std::vector<CNodePtr>> &atomics_node) const;

// launch
bool PySyncRuning() const;
bool MemoryCopyAsync(const CNodePtr &node, const vector<AddressPtr> &inputs, const vector<AddressPtr> &outputs) const;
bool LaunchAtomicClean(const CNodePtr &node, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) const;

void *GetKernelStream(const CNodePtr &node) const;
bool GetKernelRealInputs(const CNodePtr &kernel, const vector<AddressPtr> &inputs,
std::vector<AddressPtr> *real_inputs) const;
void PreprocessBeforeRunGraph(const KernelGraphPtr &graph) const;
void PreprocessBeforeRunSingleOpGraph(const KernelGraphPtr &graph) const;

// Using node to get it's atomics
mutable std::map<CNodePtr, std::vector<CNodePtr>> node_atomics_;
// Persistent cache for single op execution.
// node_atomics_ will be cleaned up in CompileGraph.
mutable std::map<CNodePtr, std::vector<CNodePtr>> node_atomics_persistent_cache_;
mutable std::set<CNodePtr> nop_op_to_memcpy_;
mutable std::mutex launch_mutex_;
AscendDeviceResManager *res_manager_{nullptr};
AscendGraphExecutor *graph_executor_{nullptr};
};
} // namespace ascend
} // namespace device
} // namespace mindspore

#endif // MINDSPORE_CCSRC_RUNTIME_HARDWARE_ASCEND_ASCEND_KERNEL_EXECUTOR_H_
@@ -0,0 +1,100 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "plugin/device/ascend/hal/hardware/ascend_utils.h"
#include <vector>
#include <string>
#include "common/util/error_manager/error_manager.h"
#include "include/common/utils/anfalgo.h"
#include "backend/common/session/anf_runtime_algorithm.h"
#include "runtime/device/ms_device_shape_transfer.h"
#include "utils/ms_context.h"

namespace mindspore {
namespace device {
namespace ascend {
constexpr auto kUnknowErrorString = "Unknown error occurred";
void ReportErrorMessage() {
const string &error_message = ErrorManager::GetInstance().GetErrorMessage();
if (!error_message.empty() && error_message.find(kUnknowErrorString) == string::npos) {
MS_LOG(ERROR) << "Ascend error occurred, error message:\n" << error_message;
}
}

void SetErrorManagerContext() { ErrorManager::GetInstance().GenWorkStreamIdDefault(); }

void ReportWarningMessage() {
const string &warning_message = ErrorManager::GetInstance().GetWarningMessage();
if (!warning_message.empty()) {
MS_LOG(WARNING) << "Ascend warning message:\n" << warning_message;
}
}

bool IsGraphMode() {
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
return context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode;
}

bool IsDynamicShapeGraph(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return());
return std::any_of(node_list.begin(), node_list.end(),
[](const AnfNodePtr &node) { return common::AnfAlgo::IsDynamicShape(node); });
}

void AssignOutputNopNodeDeviceAddress(const KernelGraphPtr &graph, const device::DeviceContext *device_context) {
MS_EXCEPTION_IF_NULL(graph);
auto outputs = common::AnfAlgo::GetAllOutput(graph->output(), {prim::kPrimTupleGetItem});
for (auto output : outputs) {
if (!output->isa<CNode>() || !AnfUtils::IsRealKernel(output)) {
continue;
}

if (!common::AnfAlgo::IsNopNode(output)) {
continue;
}

if (!common::AnfAlgo::IsNeedSkipNopOpAddr(output)) {
continue;
}

size_t input_num = common::AnfAlgo::GetInputTensorNum(output);
if (input_num != 1) {
MS_LOG(WARNING) << "The input number of nop node :" << output->fullname_with_scope() << " is " << input_num
<< ", not equal 1";
continue;
}

auto real_input_index = AnfAlgo::GetRealInputIndex(output, 0);
auto pre_node_out_device_address = AnfAlgo::GetPrevNodeOutputAddr(output, real_input_index);
MS_EXCEPTION_IF_NULL(pre_node_out_device_address);
auto ptr = pre_node_out_device_address->GetPtr();
auto size = pre_node_out_device_address->GetSize();
std::string output_format = AnfAlgo::GetOutputFormat(output, 0);
auto output_type = AnfAlgo::GetOutputDeviceDataType(output, 0);
auto device_address = device_context->device_res_manager_->CreateDeviceAddress(
const_cast<void *>(ptr), size, output_format, output_type, trans::GetRuntimePaddingShape(output, 0));
device_address->set_is_ptr_persisted(true);
device_address->set_host_shape(trans::GetRuntimePaddingShape(output, 0));
AnfAlgo::SetOutputAddr(device_address, 0, output.get());
common::AnfAlgo::SetNodeAttr(kAttrSkipNopOpAddr, MakeValue(false), output);
MS_LOG(INFO) << "Assign device address to output nop node " << output->fullname_with_scope();
}
}
} // namespace ascend
} // namespace device
} // namespace mindspore
@@ -0,0 +1,40 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_RUNTIME_HARDWARE_ASCEND_ASCEND_UTILS_H_
#define MINDSPORE_CCSRC_RUNTIME_HARDWARE_ASCEND_ASCEND_UTILS_H_

#include "plugin/device/ascend/hal/hardware/ascend_device_context.h"
#include "backend/common/session/kernel_graph.h"

namespace mindspore {
namespace device {
namespace ascend {
void ReportErrorMessage();
void ReportWarningMessage();
void SetErrorManagerContext();

bool IsGraphMode();
bool IsDynamicShapeGraph(const FuncGraphPtr &func_graph);

// Some NOP nodes have been hidden from the execution order and have no output device address. This function creates
// output device addresses for these nodes, and the output device address is the same as the input device address.
void AssignOutputNopNodeDeviceAddress(const KernelGraphPtr &graph, const device::DeviceContext *device_context);
} // namespace ascend
} // namespace device
} // namespace mindspore

#endif // MINDSPORE_CCSRC_RUNTIME_HARDWARE_ASCEND_ASCEND_UTILS_H_
@@ -58,13 +58,11 @@ void CPUDeviceContext::Initialize() {
if (initialized_) {
return;
}

mem_manager_ = std::make_shared<CPUMemoryManager>();
MS_EXCEPTION_IF_NULL(mem_manager_);
device_res_manager_->Initialize();

#ifndef ENABLE_SECURITY
// Dump json config file if dump is enabled.
auto rank_id = GetRankID();
auto rank_id = 0;
auto &json_parser = DumpJsonParser::GetInstance();
json_parser.Parse();
json_parser.CopyDumpJsonToDir(rank_id);
@@ -74,7 +72,14 @@ void CPUDeviceContext::Initialize() {
initialized_ = true;
}

void CPUDeviceContext::Destroy() {
void CPUDeviceContext::Destroy() { device_res_manager_->Destroy(); }

void CPUDeviceResManager::Initialize() {
mem_manager_ = std::make_shared<CPUMemoryManager>();
MS_EXCEPTION_IF_NULL(mem_manager_);
}

void CPUDeviceResManager::Destroy() {
// Release memory.
if (mem_manager_ != nullptr) {
mem_manager_->Finalize();
@@ -82,30 +87,32 @@ void CPUDeviceContext::Destroy() {
}
}

void *CPUDeviceContext::AllocateMemory(size_t size) const {
void *CPUDeviceResManager::AllocateMemory(size_t size) const {
MS_EXCEPTION_IF_NULL(mem_manager_);
return mem_manager_->MallocMemFromMemPool(size, false);
}

void CPUDeviceContext::FreeMemory(void *ptr) const {
void CPUDeviceResManager::FreeMemory(void *ptr) const {
MS_EXCEPTION_IF_NULL(ptr);
MS_EXCEPTION_IF_NULL(mem_manager_);
mem_manager_->FreeMemFromMemPool(ptr);
}

std::vector<void *> CPUDeviceContext::AllocateContinuousMemory(const std::vector<size_t> &size_list) const {
std::vector<void *> CPUDeviceResManager::AllocateContinuousMemory(const std::vector<size_t> &size_list) const {
return mem_manager_->MallocContinuousMemFromMemPool(size_list);
}

DeviceAddressPtr CPUDeviceContext::CreateDeviceAddress(void *const device_ptr, size_t device_size, const string &format,
TypeId type_id, const ShapeVector &shape) const {
auto device_address = std::make_shared<CPUDeviceAddress>(
device_ptr, device_size, format, type_id, device_context_key_.device_name_, device_context_key_.device_id_);
DeviceAddressPtr CPUDeviceResManager::CreateDeviceAddress(void *const device_ptr, size_t device_size,
const string &format, TypeId type_id,
const ShapeVector &shape) const {
auto device_address = std::make_shared<CPUDeviceAddress>(device_ptr, device_size, format, type_id,
device_context_->device_context_key().device_name_,
device_context_->device_context_key().device_id_);
device_address->set_host_shape(shape);
return device_address;
}

void CPUDeviceContext::OptimizeGraph(const FuncGraphPtr &graph) const {
void CPUKernelExecutor::OptimizeGraph(const FuncGraphPtr &graph) const {
MS_EXCEPTION_IF_NULL(graph);
auto kernel_graph = graph->cast<KernelGraphPtr>();
MS_EXCEPTION_IF_NULL(kernel_graph);
@@ -133,7 +140,7 @@ void CPUDeviceContext::OptimizeGraph(const FuncGraphPtr &graph) const {
}
}

void CPUDeviceContext::UpdateKernelRefInfo(const KernelGraphPtr &graph) const {
void CPUKernelExecutor::UpdateKernelRefInfo(const KernelGraphPtr &graph) const {
MS_EXCEPTION_IF_NULL(graph);
const std::vector<CNodePtr> &kernels = graph->execution_order();
for (const auto &kernel : kernels) {
@@ -152,7 +159,7 @@ void CPUDeviceContext::UpdateKernelRefInfo(const KernelGraphPtr &graph) const {
}
}

void CPUDeviceContext::OptimizeGraphImpl(const KernelGraphPtr &graph) const {
void CPUKernelExecutor::OptimizeGraphImpl(const KernelGraphPtr &graph) const {
MS_EXCEPTION_IF_NULL(graph);
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
@@ -215,7 +222,7 @@ void SetKernelInfoBeforeCreateKernel(const std::vector<CNodePtr> &nodes) {
}
} // namespace

void CPUDeviceContext::SetOperatorInfo(const KernelGraphPtr &graph) const {
void CPUKernelExecutor::SetOperatorInfo(const KernelGraphPtr &graph) const {
#ifdef ENABLE_AKG
bool do_expand = false;
#endif
@@ -252,7 +259,7 @@ void CPUDeviceContext::SetOperatorInfo(const KernelGraphPtr &graph) const {
}
#endif
}
void CPUDeviceContext::CreateKernel(const std::vector<CNodePtr> &nodes) const {
void CPUKernelExecutor::CreateKernel(const std::vector<CNodePtr> &nodes) const {
SetKernelInfoBeforeCreateKernel(nodes);

kernel::KernelMeta *bin_map = kernel::KernelMeta::GetInstance();
@@ -308,7 +315,7 @@ void CPUDeviceContext::CreateKernel(const std::vector<CNodePtr> &nodes) const {
#endif
}

void CPUDeviceContext::PreprocessBeforeRun(const FuncGraphPtr &graph) const {
void CPUKernelExecutor::PreprocessBeforeRun(const FuncGraphPtr &graph) const {
MS_EXCEPTION_IF_NULL(graph);
auto kernel_graph = graph->cast<KernelGraphPtr>();
MS_EXCEPTION_IF_NULL(kernel_graph);
@@ -332,9 +339,9 @@ void CPUDeviceContext::PreprocessBeforeRun(const FuncGraphPtr &graph) const {
}
}

bool CPUDeviceContext::LaunchKernel(const CNodePtr &kernel, const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) const {
bool CPUKernelExecutor::LaunchKernel(const CNodePtr &kernel, const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) const {
MS_EXCEPTION_IF_NULL(kernel);
MS_LOG(DEBUG) << "Launch kernel: " << kernel->fullname_with_scope();
auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
@@ -358,7 +365,7 @@ bool CPUDeviceContext::LaunchKernel(const CNodePtr &kernel, const std::vector<Ad
return DoLaunchKernel(kernel_mod, inputs, workspace, outputs);
}

bool CPUDeviceContext::LoadCollectiveCommLib() {
bool CPUDeviceResManager::LoadCollectiveCommLib() {
bool using_mpi = common::UseMPI();
if (using_mpi) {
std::string mpi_comm_lib_name = "libmpi_collective.so";
@@ -384,9 +391,9 @@ bool CPUDeviceContext::LoadCollectiveCommLib() {
return true;
}

bool CPUDeviceContext::LaunchKernelWithProfiling(const CNodePtr &kernel, const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) const {
bool CPUKernelExecutor::LaunchKernelWithProfiling(const CNodePtr &kernel, const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) const {
MS_EXCEPTION_IF_NULL(kernel);

auto profiler_inst = profiler::cpu::CPUProfiler::GetInstance();
@@ -404,9 +411,9 @@ bool CPUDeviceContext::LaunchKernelWithProfiling(const CNodePtr &kernel, const s
return ret;
}

bool CPUDeviceContext::DoLaunchKernel(KernelMod *const kernel_mod, const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) const {
bool CPUKernelExecutor::DoLaunchKernel(KernelMod *const kernel_mod, const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) const {
MS_EXCEPTION_IF_NULL(kernel_mod);
return kernel_mod->Launch(inputs, workspace, outputs, nullptr);
}
@@ -27,25 +27,36 @@
namespace mindspore {
namespace device {
namespace cpu {
class CPUDeviceContext : public DeviceContext {
class CPUDeviceResManager : public DeviceResManager {
public:
explicit CPUDeviceContext(const DeviceContextKey &device_context_key)
: DeviceContext(device_context_key), mem_manager_(nullptr), initialized_(false) {}
~CPUDeviceContext() override = default;
CPUDeviceResManager() : mem_manager_(nullptr) {}
~CPUDeviceResManager() override = default;

void Initialize() override;

void Destroy() override;

// Relevant function to allocate and free device memory of raw ptr.
void *AllocateMemory(size_t size) const override;
void FreeMemory(void *ptr) const override;

std::vector<void *> AllocateContinuousMemory(const std::vector<size_t> &size_list) const override;

DeviceAddressPtr CreateDeviceAddress(void *const device_ptr, size_t device_size, const string &format, TypeId type_id,
const ShapeVector &shape = ShapeVector()) const override;

bool LoadCollectiveCommLib() override;

protected:
// Relevant function to allocate and free device memory of raw ptr.
void *AllocateMemory(size_t size) const override;
void FreeMemory(void *ptr) const override;

private:
std::shared_ptr<MemoryManager> mem_manager_;
};

class CPUKernelExecutor : public KernelExecutor {
public:
CPUKernelExecutor() = default;
~CPUKernelExecutor() override = default;

void OptimizeGraph(const FuncGraphPtr &graph) const override;

void CreateKernel(const std::vector<CNodePtr> &nodes) const override;
@@ -55,11 +66,7 @@ class CPUDeviceContext : public DeviceContext {
bool LaunchKernel(const CNodePtr &kernel, const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace, const std::vector<AddressPtr> &outputs) const override;

bool LoadCollectiveCommLib() override;

private:
DISABLE_COPY_AND_ASSIGN(CPUDeviceContext);

// Select the matching backend kernels according to the data type and format of input and output for all
// execution operators, and set final device data type and format information for backend kernels, device
// data type and format which replace original data type and format will use for executing kernels.
@@ -79,7 +86,22 @@ class CPUDeviceContext : public DeviceContext {
void UpdateKernelRefInfo(const KernelGraphPtr &graph) const;

mutable std::mutex launch_mutex_;
std::shared_ptr<MemoryManager> mem_manager_;
};

class CPUDeviceContext : public DeviceInterface<CPUKernelExecutor, CPUDeviceResManager> {
public:
explicit CPUDeviceContext(const DeviceContextKey &device_context_key)
: DeviceInterface(device_context_key), initialized_(false) {}
~CPUDeviceContext() override = default;

void Initialize() override;

void Destroy() override;

RunMode GetRunMode(const FuncGraphPtr &func_graph) const override { return RunMode::kKernelMode; }

private:
DISABLE_COPY_AND_ASSIGN(CPUDeviceContext);
bool initialized_;
};
} // namespace cpu
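As a reading aid for the split shown in this hunk: the resource manager owns memory and streams, the kernel executor owns graph and kernel work, and the device context only aggregates the two, so call sites reach them through device_res_manager_ and kernel_executor_. The following standalone C++ sketch (hypothetical, heavily simplified names, not the actual MindSpore headers in runtime/hardware/device_context.h) illustrates that shape; the later hunks in this commit follow the same pattern, replacing direct device_context calls with calls through device_res_manager_.

#include <cstddef>
#include <cstdlib>
#include <memory>

// Hypothetical, trimmed-down illustration of the DeviceResManager / KernelExecutor /
// DeviceContext split; not the real MindSpore classes.
struct DeviceResManager {
  void *AllocateMemory(size_t size) const { return std::malloc(size); }  // stand-in for pooled allocation
  void FreeMemory(void *ptr) const { std::free(ptr); }
};

struct KernelExecutor {
  bool LaunchKernel() const { return true; }  // stand-in for launching a KernelMod
};

struct DeviceContext {
  std::unique_ptr<DeviceResManager> device_res_manager_ = std::make_unique<DeviceResManager>();
  std::unique_ptr<KernelExecutor> kernel_executor_ = std::make_unique<KernelExecutor>();
};

int main() {
  DeviceContext device_context;
  // Resource work is routed through device_res_manager_ ...
  void *buf = device_context.device_res_manager_->AllocateMemory(1024);
  // ... and kernel work through kernel_executor_.
  bool ok = device_context.kernel_executor_->LaunchKernel();
  device_context.device_res_manager_->FreeMemory(buf);
  return ok ? 0 : 1;
}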
@@ -56,7 +56,7 @@ void GPUBucket::AllocateContinuousMemory(const std::vector<DeviceAddressPtr> &to
const auto &device_context =
device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_name_, device_id_});
MS_EXCEPTION_IF_NULL(device_context);
std::vector<void *> dev_ptr_list = device_context->AllocateContinuousMemory(size_list);
std::vector<void *> dev_ptr_list = device_context->device_res_manager_->AllocateContinuousMemory(size_list);
if (dev_ptr_list.empty() || dev_ptr_list.size() != to_allocate_address.size()) {
MS_LOG(EXCEPTION) << "Allocate continuous memory failed, device ptr list size: " << dev_ptr_list.size()
<< ", address list size:" << to_allocate_address.size();
@@ -70,10 +70,10 @@ void GPUBucket::AllocateContinuousMemory(const std::vector<DeviceAddressPtr> &to
MS_LOG(EXCEPTION) << "Device size from old device address is larger than new device address, " << old_size
<< " vs " << size_list[i];
}
auto new_dev_addr = device_context->CreateDeviceAddress(dev_ptr_list[i], old_size, old_dev_addr->format(),
old_dev_addr->type_id(), old_dev_addr->host_shape());
auto new_dev_addr = device_context->device_res_manager_->CreateDeviceAddress(
dev_ptr_list[i], old_size, old_dev_addr->format(), old_dev_addr->type_id(), old_dev_addr->host_shape());
new_dev_addr->SyncDeviceToDevice(old_dev_addr.get());
device_context->FreeMemory(old_dev_addr.get());
device_context->device_res_manager_->FreeMemory(old_dev_addr.get());
}
to_allocate_address[i]->set_ptr(dev_ptr_list[i]);
to_allocate_address[i]->SetSize(size_list[i]);
@@ -37,7 +37,7 @@ BlockQueueStatus_T GpuDataQueueDynamic::Push(std::vector<DataQueueItem> data) {
MS_LOG(ERROR) << "Invalid Input: ptr: " << item.data_ptr_ << ", len: " << item.data_len_;
return ERROR_INPUT;
}
void *addr = device_context_->AllocateMemory(item.data_len_);
void *addr = device_context_->device_res_manager_->AllocateMemory(item.data_len_);
CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(addr, item.data_ptr_, item.data_len_, cudaMemcpyHostToDevice, stream_),
"Cuda Memcpy Error");
item.device_addr_ = addr;
@@ -97,7 +97,7 @@ bool GPUDeviceAddress::SyncHostToDevice(size_t size, const void *host_ptr) const
device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_name_, device_id_});
auto gpu_device_context = dynamic_cast<GPUDeviceContext *>(device_context);
MS_EXCEPTION_IF_NULL(gpu_device_context);
if (!gpu_device_context->BindDeviceToCurrentThread()) {
if (!gpu_device_context->device_res_manager_->BindDeviceToCurrentThread()) {
MS_LOG(EXCEPTION) << "BindDeviceToCurrentThread failed.";
}
}
@@ -48,6 +48,9 @@
#ifdef ENABLE_DEBUGGER
#include "debug/debugger/debugger.h"
#endif
#ifndef ENABLE_SECURITY
#include "debug/data_dump/dump_json_parser.h"
#endif
#include "backend/common/pass/optimize_updatestate.h"
#include "abstract/ops/primitive_infer_map.h"
#include "common/graph_kernel/adapter/expander.h"
@@ -62,26 +65,43 @@ static thread_local bool cur_thread_device_inited{false};

void GPUDeviceContext::Initialize() {
if (initialized_ == true) {
if (!BindDeviceToCurrentThread()) {
if (!device_res_manager_->BindDeviceToCurrentThread()) {
MS_LOG(EXCEPTION) << "BindDeviceToCurrentThread failed.";
}
GPUMemoryAllocator::GetInstance().CheckMaxDeviceMemory();
return;
}

device_res_manager_->Initialize();
auto gpu_kernel_executor = dynamic_cast<GPUKernelExecutor *>(kernel_executor_.get());
MS_EXCEPTION_IF_NULL(gpu_kernel_executor);
gpu_kernel_executor->Initialize();

#ifndef ENABLE_SECURITY
// Dump json config file if dump is enabled.
auto rank_id = gpu_kernel_executor->GetRankID();
auto &json_parser = DumpJsonParser::GetInstance();
json_parser.Parse();
json_parser.CopyDumpJsonToDir(rank_id);
json_parser.CopyMSCfgJsonToDir(rank_id);
#endif
initialized_ = true;
}

void GPUDeviceResManager::Initialize() {
// Set device id
if (CollectiveInitializer::instance().collective_inited()) {
DeviceContextKey old_key = device_context_key_;
device_context_key_.device_id_ = CollectiveInitializer::instance().local_rank_id();
DeviceContextKey old_key = device_context_->device_context_key();
device_context_->device_context_key_.device_id_ = CollectiveInitializer::instance().local_rank_id();

DeviceContextManager::GetInstance().UpdateDeviceContextKey(old_key, device_context_key_);
DeviceContextManager::GetInstance().UpdateDeviceContextKey(old_key, device_context_->device_context_key());

auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
ms_context->set_param<uint32_t>(MS_CTX_DEVICE_ID, device_context_key_.device_id_);
ms_context->set_param<uint32_t>(MS_CTX_DEVICE_ID, device_context_->device_context_key().device_id_);
}

MS_LOG(INFO) << "Set GPU device id index " << device_context_key_.device_id_;
MS_LOG(INFO) << "Set GPU device id index " << device_context_->device_context_key().device_id_;
// Set device id and initialize device resource.
if (!InitDevice()) {
MS_LOG(EXCEPTION) << "GPU InitDevice failed.";
@@ -96,7 +116,8 @@ void GPUDeviceContext::Initialize() {
if (CollectiveInitializer::instance().collective_inited()) {
auto collective_handle = CollectiveInitializer::instance().collective_handle();
if (collective_handle != nullptr) {
MS_LOG(INFO) << "Start initializing NCCL communicator for device " << device_context_key_.device_id_;
MS_LOG(INFO) << "Start initializing NCCL communicator for device "
<< device_context_->device_context_key().device_id_;
auto init_nccl_comm_funcptr =
reinterpret_cast<InitNCCLComm>(dlsym(const_cast<void *>(collective_handle), "InitNCCLComm"));
MS_EXCEPTION_IF_NULL(init_nccl_comm_funcptr);
@@ -104,27 +125,18 @@ void GPUDeviceContext::Initialize() {
MS_LOG(INFO) << "End initializing NCCL communicator.";
}
}

#ifndef ENABLE_SECURITY
// Dump json config file if dump is enabled.
auto rank_id = GetRankID();
auto &json_parser = DumpJsonParser::GetInstance();
json_parser.Parse();
json_parser.CopyDumpJsonToDir(rank_id);
json_parser.CopyMSCfgJsonToDir(rank_id);
#endif
initialized_ = true;
}

bool GPUDeviceContext::InitDevice() {
bool GPUDeviceResManager::InitDevice() {
if (GPUDeviceManager::GetInstance().device_count() <= 0) {
MS_LOG(ERROR) << "No GPU device found.";
return false;
}

if (!GPUDeviceManager::GetInstance().is_device_id_init()) {
if (!GPUDeviceManager::GetInstance().set_cur_device_id(device_context_key_.device_id_)) {
MS_LOG(ERROR) << "Failed to set current device id: " << SizeToInt(device_context_key_.device_id_);
if (!GPUDeviceManager::GetInstance().set_cur_device_id(device_context_->device_context_key().device_id_)) {
MS_LOG(ERROR) << "Failed to set current device id: "
<< SizeToInt(device_context_->device_context_key().device_id_);
return false;
}
}
@@ -154,19 +166,7 @@ bool GPUDeviceContext::InitDevice() {
return true;
}

void GPUDeviceContext::Destroy() {
// Release GPU buffer manager resource
#ifdef ENABLE_DEBUGGER
auto debugger = Debugger::GetInstance();
if (debugger && debugger->debugger_enabled()) {
debugger->SetTrainingDone(true);
bool ret = debugger->SendMetadata(false);
if (!ret) {
MS_LOG(ERROR) << "Failed to SendMetadata when finalize";
}
}
#endif

void GPUDeviceResManager::Destroy() {
if (DataQueueMgr::GetInstance().IsInit()) {
if (!DataQueueMgr::GetInstance().IsClosed() && !DataQueueMgr::GetInstance().CloseNotify()) {
MS_LOG(ERROR) << "Could not close gpu data queue.";
@@ -184,7 +184,23 @@ void GPUDeviceContext::Destroy() {
}
}

void *GPUDeviceContext::AllocateMemory(size_t size) const {
void GPUDeviceContext::Destroy() {
#ifdef ENABLE_DEBUGGER
auto debugger = Debugger::GetInstance();
if (debugger && debugger->debugger_enabled()) {
debugger->SetTrainingDone(true);
bool ret = debugger->SendMetadata(false);
if (!ret) {
MS_LOG(ERROR) << "Failed to SendMetadata when finalize";
}
}
#endif
auto gpu_kernel_executor = dynamic_cast<GPUKernelExecutor *>(kernel_executor_.get());
gpu_kernel_executor->Destroy();
device_res_manager_->Destroy();
}

void *GPUDeviceResManager::AllocateMemory(size_t size) const {
MS_EXCEPTION_IF_NULL(mem_manager_);
if (!BindDeviceToCurrentThread()) {
return nullptr;
@@ -192,13 +208,13 @@ void *GPUDeviceContext::AllocateMemory(size_t size) const {
return mem_manager_->MallocMemFromMemPool(size, false);
}

void GPUDeviceContext::FreeMemory(void *ptr) const {
void GPUDeviceResManager::FreeMemory(void *ptr) const {
MS_EXCEPTION_IF_NULL(mem_manager_);
MS_EXCEPTION_IF_NULL(ptr);
mem_manager_->FreeMemFromMemPool(ptr);
}

std::vector<void *> GPUDeviceContext::AllocateContinuousMemory(const std::vector<size_t> &size_list) const {
std::vector<void *> GPUDeviceResManager::AllocateContinuousMemory(const std::vector<size_t> &size_list) const {
if (!BindDeviceToCurrentThread()) {
std::vector<void *> ptr_list;
return ptr_list;
@@ -206,15 +222,17 @@ std::vector<void *> GPUDeviceContext::AllocateContinuousMemory(const std::vector
return mem_manager_->MallocContinuousMemFromMemPool(size_list);
}

DeviceAddressPtr GPUDeviceContext::CreateDeviceAddress(void *const device_ptr, size_t device_size, const string &format,
TypeId type_id, const ShapeVector &shape) const {
auto device_address = std::make_shared<GPUDeviceAddress>(
device_ptr, device_size, format, type_id, device_context_key_.device_name_, device_context_key_.device_id_);
DeviceAddressPtr GPUDeviceResManager::CreateDeviceAddress(void *const device_ptr, size_t device_size,
const string &format, TypeId type_id,
const ShapeVector &shape) const {
auto device_address = std::make_shared<GPUDeviceAddress>(device_ptr, device_size, format, type_id,
device_context_->device_context_key().device_name_,
device_context_->device_context_key().device_id_);
device_address->set_host_shape(shape);
return device_address;
}

void GPUDeviceContext::PreprocessBeforeRun(const FuncGraphPtr &graph) const {
void GPUKernelExecutor::PreprocessBeforeRun(const FuncGraphPtr &graph) const {
auto kernel_graph = graph->cast<KernelGraphPtr>();
MS_EXCEPTION_IF_NULL(kernel_graph);
auto profiler_inst = profiler::gpu::GPUProfiler::GetInstance();
@@ -232,13 +250,13 @@ void GPUDeviceContext::PreprocessBeforeRun(const FuncGraphPtr &graph) const {
}
}

void GPUDeviceContext::OptimizeGraphWithoutDeviceInfo(const KernelGraphPtr &graph) const {
void GPUKernelExecutor::OptimizeGraphWithoutDeviceInfo(const KernelGraphPtr &graph) const {
MS_EXCEPTION_IF_NULL(graph);
// Operator fusion optimization.
FuseOperators(graph);
}

void GPUDeviceContext::OptimizeGraphWithDeviceInfo(const KernelGraphPtr &graph) const {
void GPUKernelExecutor::OptimizeGraphWithDeviceInfo(const KernelGraphPtr &graph) const {
MS_EXCEPTION_IF_NULL(graph);
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
@@ -272,7 +290,7 @@ void GPUDeviceContext::OptimizeGraphWithDeviceInfo(const KernelGraphPtr &graph)
graph->SetExecOrderByDefault();
}

void GPUDeviceContext::FuseOperators(const KernelGraphPtr &graph) const {
void GPUKernelExecutor::FuseOperators(const KernelGraphPtr &graph) const {
MS_EXCEPTION_IF_NULL(graph);
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
@@ -419,7 +437,14 @@ std::lock_guard<std::mutex> LockLaunchKernel(const void *stream) {
}
} // namespace

void GPUDeviceContext::OptimizeGraph(const FuncGraphPtr &graph) const {
void GPUKernelExecutor::Initialize() {
res_manager_ = dynamic_cast<GPUDeviceResManager *>(device_context_->device_res_manager_.get());
MS_EXCEPTION_IF_NULL(res_manager_);
}

void GPUKernelExecutor::Destroy() { res_manager_ = nullptr; }

void GPUKernelExecutor::OptimizeGraph(const FuncGraphPtr &graph) const {
MS_EXCEPTION_IF_NULL(graph);
auto kernel_graph = graph->cast<KernelGraphPtr>();
MS_EXCEPTION_IF_NULL(kernel_graph);
@@ -458,7 +483,7 @@ void GPUDeviceContext::OptimizeGraph(const FuncGraphPtr &graph) const {
}
}

void GPUDeviceContext::UpdateKernelRefInfo(const KernelGraphPtr &graph) const {
void GPUKernelExecutor::UpdateKernelRefInfo(const KernelGraphPtr &graph) const {
MS_EXCEPTION_IF_NULL(graph);
const std::vector<CNodePtr> &kernels = graph->execution_order();
for (const auto &kernel : kernels) {
@@ -478,7 +503,7 @@ void GPUDeviceContext::UpdateKernelRefInfo(const KernelGraphPtr &graph) const {
}
}

void GPUDeviceContext::SetOperatorInfo(const KernelGraphPtr &graph) const {
void GPUKernelExecutor::SetOperatorInfo(const KernelGraphPtr &graph) const {
bool do_expand = false;
auto &node_list = graph->execution_order();
for (auto &node : node_list) {
@@ -504,16 +529,16 @@ void GPUDeviceContext::SetOperatorInfo(const KernelGraphPtr &graph) const {
}
}

void GPUDeviceContext::CreateKernel(const std::vector<CNodePtr> &nodes) const {
void GPUKernelExecutor::CreateKernel(const std::vector<CNodePtr> &nodes) const {
SetKernelInfoBeforeCreateKernel(nodes);
CreateGPUKernel(nodes);
}

bool GPUDeviceContext::LaunchKernel(const CNodePtr &kernel, const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) const {
bool GPUKernelExecutor::LaunchKernel(const CNodePtr &kernel, const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) const {
MS_EXCEPTION_IF_NULL(kernel);
if (!BindDeviceToCurrentThread()) {
if (!res_manager_->BindDeviceToCurrentThread()) {
return false;
}
bool ret = true;
@@ -546,16 +571,16 @@ bool GPUDeviceContext::LaunchKernel(const CNodePtr &kernel, const std::vector<Ad
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if ((ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) &&
ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE) && !SyncStream()) {
ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE) && !res_manager_->SyncStream()) {
return false;
}

return ret;
}
#ifndef ENABLE_SECURITY
bool GPUDeviceContext::LaunchKernelWithProfiling(const CNodePtr &kernel, const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream) const {
bool GPUKernelExecutor::LaunchKernelWithProfiling(const CNodePtr &kernel, const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream) const {
MS_EXCEPTION_IF_NULL(kernel);
MS_EXCEPTION_IF_NULL(stream);

@@ -571,7 +596,7 @@ bool GPUDeviceContext::LaunchKernelWithProfiling(const CNodePtr &kernel, const s
profiler_inst->SetStepTraceOpName(profiling_trace);
}

profiler_inst->OpDataProducerBegin(kernel->fullname_with_scope(), streams_.front());
profiler_inst->OpDataProducerBegin(kernel->fullname_with_scope(), res_manager_->streams_.front());
bool ret = DoLaunchKernel(kernel, inputs, workspace, outputs, stream);
profiler_inst->OpDataProducerEnd();
profiler_inst->RecordFrameWorkInfo(kernel);
@@ -581,14 +606,14 @@ bool GPUDeviceContext::LaunchKernelWithProfiling(const CNodePtr &kernel, const s
<< (op_launch_start_end_time.second - op_launch_start_end_time.first) / kBasicTimeTransferUnit;

if (profiler_inst->GetSyncEnableFlag()) {
CHECK_RET_WITH_RETURN_ERROR(SyncStream(), "Profiler SyncStream failed.");
CHECK_RET_WITH_RETURN_ERROR(res_manager_->SyncStream(), "Profiler SyncStream failed.");
}
return ret;
}
#endif
bool GPUDeviceContext::DoLaunchKernel(const CNodePtr &kernel, const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace, const std::vector<AddressPtr> &outputs,
void *stream) const {
bool GPUKernelExecutor::DoLaunchKernel(const CNodePtr &kernel, const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace, const std::vector<AddressPtr> &outputs,
void *stream) const {
MS_EXCEPTION_IF_NULL(kernel);
MS_EXCEPTION_IF_NULL(stream);
auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
@@ -596,24 +621,24 @@ bool GPUDeviceContext::DoLaunchKernel(const CNodePtr &kernel, const std::vector<
return kernel_mod->Launch(inputs, workspace, outputs, stream);
}

void *GPUDeviceContext::GetLaunchKernelStream(const CNodePtr &kernel) const {
void *GPUKernelExecutor::GetLaunchKernelStream(const CNodePtr &kernel) const {
void *stream = nullptr;
if (common::AnfAlgo::HasNodeAttr(kAttrStream, kernel)) {
auto stream_id = common::AnfAlgo::GetNodeAttr<size_t>(kernel, kAttrStream);
auto iter = stream_ids_.find(stream_id);
if (iter == stream_ids_.end()) {
auto iter = res_manager_->stream_ids_.find(stream_id);
if (iter == res_manager_->stream_ids_.end()) {
MS_LOG(EXCEPTION) << "Can not find stream for stream id: " << stream_id;
}
stream = iter->second;
} else {
stream = streams_.front();
stream = res_manager_->streams_.front();
}

MS_EXCEPTION_IF_NULL(stream);
return stream;
}

bool GPUDeviceContext::SyncStream(size_t stream_id) const {
bool GPUDeviceResManager::SyncStream(size_t stream_id) const {
void *stream = nullptr;
auto iter = stream_ids_.find(stream_id);
if (iter != stream_ids_.end()) {
@@ -637,7 +662,7 @@ bool GPUDeviceContext::SyncStream(size_t stream_id) const {
return result;
}

bool GPUDeviceContext::CreateStream(void **stream) const {
bool GPUDeviceResManager::CreateStream(void **stream) const {
MS_EXCEPTION_IF_NULL(stream);
if (!CudaDriver::CreateStream(stream)) {
MS_LOG(ERROR) << "Failed to create CUDA stream.";
@@ -646,7 +671,7 @@ bool GPUDeviceContext::CreateStream(void **stream) const {
return true;
}

bool GPUDeviceContext::DestroyStream(void *stream) const {
bool GPUDeviceResManager::DestroyStream(void *stream) const {
MS_EXCEPTION_IF_NULL(stream);
if (!CudaDriver::DestroyStream(stream)) {
MS_LOG(ERROR) << "Failed to destroy CUDA stream.";
@@ -655,7 +680,7 @@ bool GPUDeviceContext::DestroyStream(void *stream) const {
return true;
}

uint32_t GPUDeviceContext::GetRankID() const {
uint32_t GPUKernelExecutor::GetRankID() const {
bool collective_inited = CollectiveInitializer::instance().collective_inited();
uint32_t rank_id = 0;
if (collective_inited) {
@@ -666,20 +691,21 @@ uint32_t GPUDeviceContext::GetRankID() const {
return rank_id;
}

std::shared_ptr<Bucket> GPUDeviceContext::CreateBucket(uint32_t bucket_id, uint32_t bucket_size) const {
std::shared_ptr<Bucket> GPUKernelExecutor::CreateBucket(uint32_t bucket_id, uint32_t bucket_size) const {
auto bucket = std::make_shared<GPUBucket>(bucket_id, bucket_size);
MS_EXCEPTION_IF_NULL(bucket);
// One computation stream, one communication stream.
const size_t min_num_of_stream = 2;
if (min_num_of_stream > streams_.size()) {
MS_LOG(EXCEPTION) << "The total stream num: " << streams_.size() << " is less than: " << min_num_of_stream;
if (min_num_of_stream > res_manager_->streams_.size()) {
MS_LOG(EXCEPTION) << "The total stream num: " << res_manager_->streams_.size()
<< " is less than: " << min_num_of_stream;
}

bucket->Init({streams_[0]}, {streams_[1]});
bucket->Init({res_manager_->streams_[0]}, {res_manager_->streams_[1]});
return bucket;
}

bool GPUDeviceContext::LoadCollectiveCommLib() {
bool GPUDeviceResManager::LoadCollectiveCommLib() {
#ifdef ENABLE_MPI
std::string nvidia_comm_lib_name = "libnvidia_collective.so";
auto loader = std::make_shared<CollectiveCommLibLoader>(nvidia_comm_lib_name);
@@ -700,13 +726,13 @@ bool GPUDeviceContext::LoadCollectiveCommLib() {
#endif
}

bool GPUDeviceContext::BindDeviceToCurrentThread() const {
bool GPUDeviceResManager::BindDeviceToCurrentThread() const {
if (cur_thread_device_inited) {
return true;
}

if (!CudaDriver::SetDevice(UintToInt(device_context_key_.device_id_))) {
MS_LOG(ERROR) << "Failed to set device id: " << device_context_key_.device_id_;
if (!CudaDriver::SetDevice(UintToInt(device_context_->device_context_key().device_id_))) {
MS_LOG(ERROR) << "Failed to set device id: " << device_context_->device_context_key().device_id_;
return false;
}
@@ -27,11 +27,11 @@
namespace mindspore {
namespace device {
namespace gpu {
class GPUDeviceContext : public DeviceContext {
class GPUKernelExecutor;
class GPUDeviceResManager : public DeviceResManager {
public:
explicit GPUDeviceContext(const DeviceContextKey &device_context_key)
: DeviceContext(device_context_key), mem_manager_(nullptr), initialized_(false) {}
~GPUDeviceContext() override = default;
GPUDeviceResManager() : mem_manager_(nullptr) {}
~GPUDeviceResManager() override = default;

// Set device id and initialize device resource, such as stream, cudnn and cublas handle.
void Initialize() override;
@@ -41,38 +41,57 @@ class GPUDeviceContext : public DeviceContext {

bool BindDeviceToCurrentThread() const override;

// Relevant function to allocate and free device memory of raw ptr.
void *AllocateMemory(size_t size) const override;
void FreeMemory(void *ptr) const override;
std::vector<void *> AllocateContinuousMemory(const std::vector<size_t> &size_list) const override;

DeviceAddressPtr CreateDeviceAddress(void *const device_ptr, size_t device_size, const string &format, TypeId type_id,
const ShapeVector &shape = ShapeVector()) const override;

bool SyncStream(size_t stream_id = 0) const override;

bool LoadCollectiveCommLib() override;

protected:
// Relevant function to allocate and free device memory of raw ptr.
void *AllocateMemory(size_t size) const override;
void FreeMemory(void *ptr) const override;

// Really create a cuda stream.
bool CreateStream(void **stream) const override;
// Really destroy a cuda stream.
bool DestroyStream(void *stream) const override;

private:
friend class GPUKernelExecutor;
bool InitDevice();
std::shared_ptr<MemoryManager> mem_manager_;
std::vector<void *> streams_;
};

class GPUKernelExecutor : public DeprecatedKernelExecutor {
public:
GPUKernelExecutor() = default;
~GPUKernelExecutor() override = default;

void Initialize();
void Destroy();

// Optimize the kernel graph for graph mode.
void OptimizeGraph(const FuncGraphPtr &graph) const override;

void CreateKernel(const std::vector<CNodePtr> &nodes) const override;

void PreprocessBeforeRun(const FuncGraphPtr &graph) const override;

bool LaunchKernel(const CNodePtr &kernel, const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace, const std::vector<AddressPtr> &outputs) const override;

bool SyncStream(size_t stream_id = 0) const override;

uint32_t GetRankID() const override;

// Create bucket for every allreduce operator. Bucket is used in PyNative distributed training mode, one bucket
// handles all resource to launch and sync allreduce operator.
std::shared_ptr<Bucket> CreateBucket(uint32_t bucket_id, uint32_t bucket_size) const override;

bool LoadCollectiveCommLib() override;

void PreprocessBeforeRun(const FuncGraphPtr &graph) const override;

private:
DISABLE_COPY_AND_ASSIGN(GPUDeviceContext);
bool InitDevice();

// Select the matching backend kernels according to the data type and format of input and output for all
// execution operators, and set final device data type and format information for backend kernels, device
// data type and format which replace original data type and format will use for executing kernels.
@@ -104,14 +123,28 @@ class GPUDeviceContext : public DeviceContext {
// default stream.
void *GetLaunchKernelStream(const CNodePtr &kernel) const;

// Really create a cuda stream.
bool CreateStream(void **stream) const override;
// The cublas handle is not thread safe; it is not recommended that multiple threads access the same
// cublas handle at the same time, so the launch mutex is needed when multiple threads launch the cublas kernels.
mutable std::mutex launch_mutex_;
GPUDeviceResManager *res_manager_{nullptr};
};

// Really destroy a cuda stream.
bool DestroyStream(void *stream) const override;
class GPUDeviceContext : public DeviceInterface<GPUKernelExecutor, GPUDeviceResManager> {
public:
explicit GPUDeviceContext(const DeviceContextKey &device_context_key)
: DeviceInterface(device_context_key), initialized_(false) {}
~GPUDeviceContext() override = default;

std::shared_ptr<MemoryManager> mem_manager_;
std::vector<void *> streams_;
// Set device id and initialize device resource, such as stream, cudnn and cublas handle.
void Initialize() override;

// Release device memory, stream, cudnn and cublas handle, etc.
void Destroy() override;

RunMode GetRunMode(const FuncGraphPtr &func_graph) const override { return RunMode::kKernelMode; }

private:
DISABLE_COPY_AND_ASSIGN(GPUDeviceContext);
bool initialized_;
};
} // namespace gpu
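The launch-mutex comment above describes a common serialization pattern: a cuBLAS handle is not safe to share across threads, so every launch that touches the shared handle first takes the same mutex. A minimal generic sketch of that pattern (hypothetical names, no CUDA dependency) might look like:

#include <mutex>

// Hypothetical illustration of guarding a shared, non-thread-safe handle:
// every launch that uses the handle acquires the same mutex first.
class HandleGuardedLauncher {
 public:
  bool Launch() {
    std::lock_guard<std::mutex> lock(launch_mutex_);  // serialize access to the shared handle
    return LaunchWithSharedHandle();
  }

 private:
  bool LaunchWithSharedHandle() { return true; }  // stand-in for the cublas-backed kernel launch
  std::mutex launch_mutex_;
};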
@@ -61,7 +61,7 @@ bool StreamSynchronizer::SyncStream(const std::string &device_name, uint32_t tim
// If disable recovery or timeout==0, sync stream directly to improve performance.
if (!RecoveryContext::GetInstance()->enable_recovery() || timeout == 0) {
device_context->Initialize();
return device_context->SyncStream();
return device_context->device_res_manager_->SyncStream();
}

std::unique_lock<std::mutex> lock(task_mutex_);
@@ -100,7 +100,7 @@ void StreamSynchronizer::DoSyncStreamTask() {

device_context_->Initialize();
// Really sync stream.
sync_stream_ret_ = device_context_->SyncStream();
sync_stream_ret_ = device_context_->device_res_manager_->SyncStream();

{
std::unique_lock<std::mutex> lock(task_mutex_);
@@ -234,9 +234,9 @@ void FreeMemory(DeviceTensor *const device_tensor, const DeviceContext *device_c
const auto &new_device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
{device_tensor->device_name(), device_tensor->device_id()});
MS_EXCEPTION_IF_NULL(new_device_context);
new_device_context->FreeMemory(device_tensor);
new_device_context->device_res_manager_->FreeMemory(device_tensor);
} else {
device_context->FreeMemory(device_tensor);
device_context->device_res_manager_->FreeMemory(device_tensor);
}
}
@@ -399,7 +399,7 @@ void ControlActor::UpdateOutputData(OpData<DeviceTensor> *const output_data, con
MS_EXCEPTION_IF_NULL(device_context);
device::DynamicMemAllocatorDebugInfo::SetDebugInfo(GetAID().Name(), device::AllocatorType::kOther, 0);
if ((device_tensor->GetPtr() == nullptr) &&
(!device_context->AllocateMemory(device_tensor.get(), device_tensor->GetSize()))) {
(!device_context->device_res_manager_->AllocateMemory(device_tensor.get(), device_tensor->GetSize()))) {
SET_OPCONTEXT_MEMORY_ALLOC_FAIL_BY_STRATEGY(GraphExecutionStrategy::kPipeline, *context, *device_context,
formal_parameter.first->DebugString(), device_tensor->GetSize());
}
@@ -138,7 +138,7 @@ void ExitActor::IncreaseDynamicRefCounts(OpContext<DeviceTensor> *const context)
for (size_t i = 0; i < input_device_tensors_.size(); ++i) {
if ((input_device_tensors_[i] != nullptr) && (input_device_tensors_[i]->dynamic_ref_count() == 0)) {
MS_LOG(WARNING) << GetAID().Name() << " input index:" << i << " has no user and free the memory.";
device_contexts_[i]->FreeMemory(input_device_tensors_[i]);
device_contexts_[i]->device_res_manager_->FreeMemory(input_device_tensors_[i]);
}
}
}
@@ -182,9 +182,9 @@ void ExitActor::CopyDeviceAddress(OpContext<DeviceTensor> *const context) {
(void)std::transform(shape_tmp.begin(), shape_tmp.end(), std::back_inserter(host_shape), IntToSize);
}
// Create the new device tensor to take over the input_device_tensors which are the outputs of kernel graphs.
auto new_device_tensor =
device_contexts_[i]->CreateDeviceAddress(nullptr, input_device_tensor->GetSize(), input_device_tensor->format(),
input_device_tensor->type_id(), host_shape);
auto new_device_tensor = device_contexts_[i]->device_res_manager_->CreateDeviceAddress(
nullptr, input_device_tensor->GetSize(), input_device_tensor->format(), input_device_tensor->type_id(),
host_shape);
MS_EXCEPTION_IF_NULL(new_device_tensor);
(void)created_device_tensors_.emplace_back(new_device_tensor);
(void)new_device_tensors.emplace_back(new_device_tensor.get());
@@ -198,7 +198,8 @@ void ExitActor::CopyDeviceAddress(OpContext<DeviceTensor> *const context) {
// If the address ptr can't be changed, then alloc the new device memory and copy the data.
if (input_device_tensor->is_ptr_persisted()) {
device::DynamicMemAllocatorDebugInfo::SetDebugInfo(GetAID().Name(), device::AllocatorType::kOther);
if (!device_contexts_[i]->AllocateMemory(new_device_tensor.get(), new_device_tensor->GetSize())) {
if (!device_contexts_[i]->device_res_manager_->AllocateMemory(new_device_tensor.get(),
new_device_tensor->GetSize())) {
SET_OPCONTEXT_MEMORY_ALLOC_FAIL_BY_STRATEGY(GraphExecutionStrategy::kPipeline, *context, *device_contexts_[i],
GetAID().Name(), new_device_tensor->GetSize());
}
@ -30,7 +30,7 @@ void CustomActor::Run(OpContext<DeviceTensor> *const ctx) {
|
|||
// Launch custom func
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto custom_func = AnfUtils::GetCustomFunc(node);
|
||||
if (!device_contexts_[0]->BindDeviceToCurrentThread()) {
|
||||
if (!device_contexts_[0]->device_res_manager_->BindDeviceToCurrentThread()) {
|
||||
std::string error_info = "BindDevice to current thread failed: " + node->fullname_with_scope();
|
||||
SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, (*ctx), error_info);
|
||||
}
|
||||
|
|
|
@ -43,7 +43,7 @@ void SyncTensorData(const TensorPtr &host_tensor, const DeviceTensorPtr &device_
auto allocator_type = node->isa<ValueNode>() ? device::AllocatorType::kConstantValue : device::AllocatorType::kWeight;
device::DynamicMemAllocatorDebugInfo::SetDebugInfo(node->fullname_with_scope(), allocator_type, 0);
if ((device_tensor->GetPtr() == nullptr) &&
(!device_context->AllocateMemory(device_tensor.get(), device_tensor->GetSize()))) {
(!device_context->device_res_manager_->AllocateMemory(device_tensor.get(), device_tensor->GetSize()))) {
SET_OPCONTEXT_MEMORY_ALLOC_FAIL_BY_STRATEGY(strategy, *context, *device_context, node->fullname_with_scope(),
device_tensor->GetSize());
}

@ -145,7 +145,7 @@ void PrepareDataForValue(const ValuePtr &value, const KernelWithIndex &node_with
device::DynamicMemAllocatorDebugInfo::SetDebugInfo(node->fullname_with_scope(), device::AllocatorType::kConstantValue,
0);
if (!device_context->AllocateMemory(device_tensor.get(), device_tensor->GetSize())) {
if (!device_context->device_res_manager_->AllocateMemory(device_tensor.get(), device_tensor->GetSize())) {
SET_OPCONTEXT_MEMORY_ALLOC_FAIL_BY_STRATEGY(GraphExecutionStrategy::kPipeline, *context, *device_context,
node->fullname_with_scope(), device_tensor->GetSize());
}

@ -556,9 +556,9 @@ void DataPrepareActor::PrepareDataForStepMode(const std::vector<std::vector<Tens
output_type_id = common::AnfAlgo::GetOutputInferDataType(input_node, 0);
}
size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(input_node, 0);
auto device_address =
device_context->CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(input_node, 0),
output_type_id, trans::GetRuntimePaddingShape(input_node, 0));
auto device_address = device_context->device_res_manager_->CreateDeviceAddress(
nullptr, tensor_size, AnfAlgo::GetOutputFormat(input_node, 0), output_type_id,
trans::GetRuntimePaddingShape(input_node, 0));
MS_EXCEPTION_IF_NULL(device_address);
AnfAlgo::SetOutputAddr(device_address, 0, input_node.get());
device_address->SetNodeIndex(input_node, 0);

@ -652,7 +652,7 @@ void DataPrepareActor::PrepareDataForControlValueNode(const KernelWithIndex &nod
UpdateRefCount(device_tensor.get(), true);

device::DynamicMemAllocatorDebugInfo::SetDebugInfo(node->DebugString(), device::AllocatorType::kConstantValue, 0);
if (!device_context->AllocateMemory(device_tensor.get(), device_tensor->GetSize())) {
if (!device_context->device_res_manager_->AllocateMemory(device_tensor.get(), device_tensor->GetSize())) {
SET_OPCONTEXT_MEMORY_ALLOC_FAIL_BY_STRATEGY(real_strategy_, *context, *device_context,
node->fullname_with_scope(), device_tensor->GetSize());
}

@ -695,7 +695,7 @@ void DataPrepareActor::PrepareDataForValueNode(const ValueNodePtr &node, const A
device::DynamicMemAllocatorDebugInfo::SetDebugInfo(node->fullname_with_scope(),
device::AllocatorType::kConstantValue, 0);
if (!device_context->AllocateMemory(device_tensor.get(), device_tensor->GetSize())) {
if (!device_context->device_res_manager_->AllocateMemory(device_tensor.get(), device_tensor->GetSize())) {
SET_OPCONTEXT_MEMORY_ALLOC_FAIL_BY_STRATEGY(real_strategy_, *context, *device_context,
node->fullname_with_scope(), device_tensor->GetSize());
}

@ -732,7 +732,8 @@ void DataPrepareActor::CopyDataFromDeviceTensorStore(const AnfNodePtr &front_nod
auto type = backend_node->isa<ValueNode>() ? device::AllocatorType::kConstantValue : device::AllocatorType::kWeight;
device::DynamicMemAllocatorDebugInfo::SetDebugInfo(backend_node->fullname_with_scope(), type, 0);
if ((another_device_tensor->GetPtr() == nullptr) &&
(!another_device_context->AllocateMemory(another_device_tensor.get(), another_device_tensor->GetSize()))) {
(!another_device_context->device_res_manager_->AllocateMemory(another_device_tensor.get(),
another_device_tensor->GetSize()))) {
SET_OPCONTEXT_MEMORY_ALLOC_FAIL_BY_STRATEGY(real_strategy_, *context, *another_device_context,
backend_node->fullname_with_scope(),
another_device_tensor->GetSize());

@ -769,9 +770,9 @@ void DataPrepareActor::PrepareDataForWeightNode(const AnfNodePtr &backend_node,
// The step mode can't reuse the device tensor, because other actors may use the device tensor in step mode.
if ((strategy_ == GraphExecutionStrategy::kStep) ||
(device_tensor->GetDeviceType() != device_context->GetDeviceType())) {
host_tensor_address =
device_context->CreateDeviceAddress(nullptr, device_tensor->GetSize(), device_tensor->format(),
device_tensor->type_id(), device_tensor->host_shape());
host_tensor_address = device_context->device_res_manager_->CreateDeviceAddress(
nullptr, device_tensor->GetSize(), device_tensor->format(), device_tensor->type_id(),
device_tensor->host_shape());
host_tensor_address->set_from_persistent_mem(tensor->is_parameter());
} else {
host_tensor_address = device_tensor;
@ -141,8 +141,8 @@ void DeviceQueueDataSourceActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *co
// Copy data from device queue by data kernel launching.
try {
auto ret = device_contexts_[0]->LaunchKernel(data_kernel_, launch_info_.inputs_, launch_info_.workspaces_,
launch_info_.outputs_);
auto ret = device_contexts_[0]->kernel_executor_->LaunchKernel(data_kernel_, launch_info_.inputs_,
launch_info_.workspaces_, launch_info_.outputs_);
if (!ret) {
std::string error_info = "Launch kernel failed: " + data_kernel_->fullname_with_scope();
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);

@ -313,8 +313,8 @@ void HostQueueDataSourceActor::ReleaseDataNodeAddress() {
auto device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
{old_address->device_name(), old_address->device_id()});
MS_EXCEPTION_IF_NULL(device_context);
auto new_address = device_context->CreateDeviceAddress(nullptr, old_address->GetSize(), old_address->format(),
old_address->type_id(), old_address->host_shape());
auto new_address = device_context->device_res_manager_->CreateDeviceAddress(
nullptr, old_address->GetSize(), old_address->format(), old_address->type_id(), old_address->host_shape());
MS_EXCEPTION_IF_NULL(new_address);
new_address->set_original_ref_count(old_address->original_ref_count());
new_address->ResetRefCount();
@ -139,12 +139,12 @@ bool MemcpyHostToDeviceAsync(void *dst, const void *src, size_t size, const Devi
void *device_ptr = dst;
const void *host_ptr = src;

auto device_address =
device_context->CreateDeviceAddress(device_ptr, size, kOpFormat_DEFAULT, kTypeUnknown, ShapeVector());
auto device_address = device_context->device_res_manager_->CreateDeviceAddress(device_ptr, size, kOpFormat_DEFAULT,
kTypeUnknown, ShapeVector());
MS_ERROR_IF_NULL(device_address);
RETURN_IF_FALSE_WITH_LOG(
device_address->AsyncHostToDevice({}, size, kTypeUnknown, host_ptr, device_context->GetStream(stream_id)),
"Async memcpy host to device failed.");
RETURN_IF_FALSE_WITH_LOG(device_address->AsyncHostToDevice({}, size, kTypeUnknown, host_ptr,
device_context->device_res_manager_->GetStream(stream_id)),
"Async memcpy host to device failed.");

return true;
}

@ -159,12 +159,12 @@ bool MemcpyDeviceToHostAsync(void *dst, const void *src, size_t size, const Devi
void *device_ptr = const_cast<void *>(src);
void *host_ptr = dst;

auto device_address =
device_context->CreateDeviceAddress(device_ptr, size, kOpFormat_DEFAULT, kTypeUnknown, ShapeVector());
auto device_address = device_context->device_res_manager_->CreateDeviceAddress(device_ptr, size, kOpFormat_DEFAULT,
kTypeUnknown, ShapeVector());
MS_ERROR_IF_NULL(device_address);
RETURN_IF_FALSE_WITH_LOG(
device_address->AsyncDeviceToHost({}, size, kTypeUnknown, host_ptr, device_context->GetStream(stream_id)),
"Async memcpy device to host failed.");
RETURN_IF_FALSE_WITH_LOG(device_address->AsyncDeviceToHost({}, size, kTypeUnknown, host_ptr,
device_context->device_res_manager_->GetStream(stream_id)),
"Async memcpy device to host failed.");

return true;
}

@ -172,7 +172,7 @@ bool MemcpyDeviceToHostAsync(void *dst, const void *src, size_t size, const Devi
void EmbeddingCachePrefetchActor::Initialize() {
MS_EXCEPTION_IF_NULL(device_context_);
if (!device_context_->CreateStream(&stream_id_)) {
if (!device_context_->device_res_manager_->CreateStream(&stream_id_)) {
MS_LOG(EXCEPTION) << "Create stream failed.";
}

@ -227,7 +227,7 @@ void EmbeddingCachePrefetchActor::BuildEmbeddingCacheLookupKernel() {
// 3. Kernel build process.
MS_EXCEPTION_IF_NULL(device_context_);
device_context_->CreateKernel({embedding_cache_lookup_node_});
device_context_->kernel_executor_->CreateKernel({embedding_cache_lookup_node_});
}

void EmbeddingCachePrefetchActor::BuildEmbeddingCacheUpdateKernel() {

@ -252,7 +252,7 @@ void EmbeddingCachePrefetchActor::BuildEmbeddingCacheUpdateKernel() {
// 3. Kernel build process.
MS_EXCEPTION_IF_NULL(device_context_);
device_context_->CreateKernel({embedding_cache_update_node_});
device_context_->kernel_executor_->CreateKernel({embedding_cache_update_node_});
}

bool EmbeddingCachePrefetchActor::LookupDeviceCache(void *indices, void *embedding_cache, size_t indices_num,

@ -288,7 +288,8 @@ bool EmbeddingCachePrefetchActor::LookupDeviceCache(void *indices, void *embeddi
AddressPtrList kernel_outputs = {std::make_shared<Address>(outputs, indices_num * embedding_size * sizeof(float))};

MS_ERROR_IF_NULL(device_context_);
auto ret = device_context_->LaunchKernel(embedding_cache_lookup_node_, kernel_inputs, {}, kernel_outputs);
auto ret =
device_context_->kernel_executor_->LaunchKernel(embedding_cache_lookup_node_, kernel_inputs, {}, kernel_outputs);
if (!ret) {
MS_LOG(ERROR) << "Launch kernel: " << embedding_cache_lookup_node_->fullname_with_scope() << " failed.";
return false;

@ -337,7 +338,8 @@ bool EmbeddingCachePrefetchActor::UpdateDeviceCache(void *indices, void *update_
std::make_shared<Address>(embedding_cache, cache_size * embedding_size * sizeof(float))};

MS_ERROR_IF_NULL(device_context_);
auto ret = device_context_->LaunchKernel(embedding_cache_update_node_, kernel_inputs, {}, kernel_outputs);
auto ret =
device_context_->kernel_executor_->LaunchKernel(embedding_cache_update_node_, kernel_inputs, {}, kernel_outputs);
if (!ret) {
MS_LOG(ERROR) << "Launch kernel: " << embedding_cache_update_node_->fullname_with_scope() << " failed.";
return false;

@ -807,7 +809,7 @@ bool EmbeddingCachePrefetchActor::PushCacheFromDeviceToLocalHost(const HashTable
"Memcpy device to host asynchronously failed.");

MS_ERROR_IF_NULL(device_context_);
RETURN_IF_FALSE_WITH_LOG(device_context_->SyncStream(stream_id_), "Synchronize stream failed.");
RETURN_IF_FALSE_WITH_LOG(device_context_->device_res_manager_->SyncStream(stream_id_), "Synchronize stream failed.");
RETURN_IF_FALSE_WITH_LOG(
InsertLocalHostCache(embedding_size, IntToSize(swap_indices_size), host_cache_device_to_host_index,
swap_out_data.get(), host_hash_table_addr),

@ -880,7 +882,7 @@ bool EmbeddingCachePrefetchActor::PullCacheFromLocalHostToDevice(const HashTable
swap_indices_size, cache_vocab_size, embedding_size, hash_table_addr),
"Update device embedding cache failed.");
MS_ERROR_IF_NULL(device_context_);
RETURN_IF_FALSE_WITH_LOG(device_context_->SyncStream(stream_id_), "Synchronize stream failed.");
RETURN_IF_FALSE_WITH_LOG(device_context_->device_res_manager_->SyncStream(stream_id_), "Synchronize stream failed.");
return true;
}
@ -134,8 +134,8 @@ void KernelActor::FetchWorkspaceDeviceTensor() {
(void)launch_info_.workspaces_.erase(launch_info_.workspaces_.end() - size, launch_info_.workspaces_.end());
} else if (launch_info_.workspaces_.size() < workspace_sizes.size()) {
for (size_t i = launch_info_.workspaces_.size(); i < workspace_sizes.size(); ++i) {
auto device_address =
device_contexts_[0]->CreateDeviceAddress(nullptr, workspace_sizes[i], "", kTypeUnknown, ShapeVector());
auto device_address = device_contexts_[0]->device_res_manager_->CreateDeviceAddress(
nullptr, workspace_sizes[i], "", kTypeUnknown, ShapeVector());
MS_LOG(DEBUG) << "Create addr for node:" << common::AnfAlgo::GetNodeDebugString(kernel_)
<< " addr:" << device_address;
AnfAlgo::SetWorkspaceAddr(device_address, i, kernel_.get()); // set to kernel_info

@ -164,7 +164,7 @@ void AllocateMemory(const std::vector<DeviceTensor *> &alloc_list, const DeviceC
continue;
}
// Allocate memory through the device context.
if (!device_context->AllocateMemory(device_tensor, device_tensor->GetSize())) {
if (!device_context->device_res_manager_->AllocateMemory(device_tensor, device_tensor->GetSize())) {
SET_OPCONTEXT_MEMORY_ALLOC_FAIL_BY_STRATEGY(GraphExecutionStrategy::kStep, *context, *device_context, actor_name,
device_tensor->GetSize());
}

@ -183,7 +183,7 @@ void FreeMemory(const std::vector<DeviceTensor *> &free_list, const DeviceContex
if (device_tensor->ref_count() == 0) {
// Free memory through the device context.
if (device_tensor->GetPtr() != nullptr) {
device_context->FreeMemory(device_tensor);
device_context->device_res_manager_->FreeMemory(device_tensor);
}
device_tensor->ResetRefCount();
}

@ -219,7 +219,7 @@ void KernelActor::SendMemoryFreeReq(OpContext<DeviceTensor> *const context) {
// Free the address that is the temp store for kernel input copy.
for (auto &copy_input_device_tensor : copy_input_device_tensors_) {
if ((copy_input_device_tensor != nullptr) && (copy_input_device_tensor->GetPtr() != nullptr)) {
device_contexts_[0]->FreeMemory(copy_input_device_tensor.get());
device_contexts_[0]->device_res_manager_->FreeMemory(copy_input_device_tensor.get());
}
}
}

@ -311,7 +311,7 @@ void KernelActor::CopyInputDeviceTensor(const OpData<DeviceTensor> *input_data,
SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, *context, "The input index is of range.");
}
if (copy_input_device_tensors_[input_data->index_] == nullptr) {
copy_input_device_tensors_[input_data->index_] = device_contexts_[0]->CreateDeviceAddress(
copy_input_device_tensors_[input_data->index_] = device_contexts_[0]->device_res_manager_->CreateDeviceAddress(
nullptr, real_input_info->size_, real_input_info->format_, real_input_info->type_id_, real_input_info->shape_);
}
auto &new_device_tensor = copy_input_device_tensors_[input_data->index_];

@ -325,8 +325,8 @@ void KernelActor::CopyInputDeviceTensor(const OpData<DeviceTensor> *input_data,
device::DynamicMemAllocatorDebugInfo::SetDebugInfo(GetAID().Name(), device::AllocatorType::kKernelOutput,
input_data->index_);
if ((new_device_tensor->GetPtr() == nullptr) &&
(!device_contexts_[0]->AllocateMemory(new_device_tensor.get(), new_device_tensor->GetSize()))) {
if ((new_device_tensor->GetPtr() == nullptr) && (!device_contexts_[0]->device_res_manager_->AllocateMemory(
new_device_tensor.get(), new_device_tensor->GetSize()))) {
SET_OPCONTEXT_MEMORY_ALLOC_FAIL_BY_STRATEGY(strategy_, *context, *(device_contexts_[0]), GetAID().Name(),
new_device_tensor->GetSize());
}

@ -451,8 +451,8 @@ void KernelActor::PreLaunchKernel(OpContext<DeviceTensor> *) {
bool KernelActor::LaunchKernel() {
MS_EXCEPTION_IF_NULL(device_contexts_[0]);
return device_contexts_[0]->LaunchKernel(kernel_, launch_info_.inputs_, launch_info_.workspaces_,
launch_info_.outputs_);
return device_contexts_[0]->kernel_executor_->LaunchKernel(kernel_, launch_info_.inputs_, launch_info_.workspaces_,
launch_info_.outputs_);
}

void KernelActor::PostLaunchKernel(OpContext<DeviceTensor> *const context) {
@ -45,7 +45,7 @@ void MemoryManagerActor::AllocateMemory(const std::vector<DeviceTensor *> *alloc
try {
// Allocate memory through the device context.
device::DynamicMemAllocatorDebugInfo::SetDebugInfo(from_aid.Name(), device::AllocatorType::kKernelOutput);
if (!device_context->AllocateMemory(device_tensor, device_tensor->GetSize())) {
if (!device_context->device_res_manager_->AllocateMemory(device_tensor, device_tensor->GetSize())) {
SetOpContextMemoryAllocFail(from_aid.Name(), device_context, device_tensor->GetSize(), op_context);
return;
}

@ -83,7 +83,7 @@ void MemoryManagerActor::AllocateContinuousMemory(const std::vector<std::vector<
MS_EXCEPTION_IF_NULL(device_context);
// Allocate memory through the device context.
device::DynamicMemAllocatorDebugInfo::SetDebugInfo(from_aid.Name(), device::AllocatorType::kKernelOutput);
auto dev_ptr_list = device_context->AllocateContinuousMemory(size_list);
auto dev_ptr_list = device_context->device_res_manager_->AllocateContinuousMemory(size_list);
if (dev_ptr_list.empty() || dev_ptr_list.size() != alloc_list.size()) {
MS_LOG(ERROR) << "Allocate continuous memory failed, device ptr list size: " << dev_ptr_list.size()
<< ", address list size:" << alloc_list.size();

@ -100,10 +100,10 @@ void MemoryManagerActor::AllocateContinuousMemory(const std::vector<std::vector<
MS_LOG(EXCEPTION) << "Device size of old device address is larger than new device address, " << old_size
<< " vs " << size_list[index];
}
auto new_dev_addr = device_context->CreateDeviceAddress(dev_ptr_list[index], old_size, old_dev_addr->format(),
old_dev_addr->type_id(), old_dev_addr->host_shape());
auto new_dev_addr = device_context->device_res_manager_->CreateDeviceAddress(
dev_ptr_list[index], old_size, old_dev_addr->format(), old_dev_addr->type_id(), old_dev_addr->host_shape());
new_dev_addr->SyncDeviceToDevice(old_dev_addr.get());
device_context->FreeMemory(old_dev_addr.get());
device_context->device_res_manager_->FreeMemory(old_dev_addr.get());
}
alloc_list[index]->set_ptr(dev_ptr_list[index]);
alloc_list[index]->SetSize(size_list[index]);

@ -138,7 +138,7 @@ void MemoryManagerActor::AllocateBatchMemory(const std::vector<DeviceTensor *> *
try {
// Allocate memory through the device context.
device::DynamicMemAllocatorDebugInfo::SetDebugInfo(from_aid.Name(), device::AllocatorType::kKernelOutput);
if (!device_context->AllocateMemory(device_tensor, device_tensor->GetSize())) {
if (!device_context->device_res_manager_->AllocateMemory(device_tensor, device_tensor->GetSize())) {
SetOpContextMemoryAllocFail(from_aid.Name(), device_context, device_tensor->GetSize(), op_context);
return;
}
@ -223,9 +223,9 @@ TensorPtr OutputActor::CreateOutputTensor(const AnfNodePtr &output_node, size_t
if (output_node_to_tensor_device_address_.count({output_node, output_index}) > 0) {
tensor->set_device_address(output_node_to_tensor_device_address_[{output_node, output_index}]);
} else {
auto tensor_device_address =
device_context->CreateDeviceAddress(nullptr, device_tensor->GetSize(), device_tensor->format(),
device_tensor->type_id(), device_tensor->host_shape());
auto tensor_device_address = device_context->device_res_manager_->CreateDeviceAddress(
nullptr, device_tensor->GetSize(), device_tensor->format(), device_tensor->type_id(),
device_tensor->host_shape());
MS_EXCEPTION_IF_NULL(tensor_device_address);
tensor->set_device_address(tensor_device_address);
output_node_to_tensor_device_address_[{output_node, output_index}] = tensor_device_address;

@ -271,7 +271,8 @@ void OutputActor::UpdateOutputDeviceAddress() {
auto device_context = device_contexts_[i];
MS_EXCEPTION_IF_NULL(device_context);
device::DynamicMemAllocatorDebugInfo::SetDebugInfo(GetAID().Name(), device::AllocatorType::kOther);
if (!device_context->AllocateMemory(tensor_device_address.get(), tensor_device_address->GetSize())) {
if (!device_context->device_res_manager_->AllocateMemory(tensor_device_address.get(),
tensor_device_address->GetSize())) {
MS_LOG(EXCEPTION) << "Device(id:" << device_context->device_context_key().device_id_
<< ") memory isn't enough and alloc failed, kernel name: "
<< output_node->fullname_with_scope() << ", alloc size: " << tensor_device_address->GetSize()
@ -77,7 +77,11 @@ void SuperKernelActor::Run(OpContext<DeviceTensor> *const context) {
}

try {
auto ret = device_contexts_[0]->LaunchGraph(graph_);
// @TODO: @TBD: run graph with inputs and outputs
const std::vector<tensor::Tensor> inputs;
std::vector<tensor::Tensor> outputs;
const std::map<string, string> compile_options;
auto ret = device_contexts_[0]->graph_executor_->RunGraph(graph_, inputs, &outputs, compile_options);
if (!ret) {
std::string error_info = "Launch graph failed, graph id: " + std::to_string(graph_->graph_id());
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
@ -258,8 +258,8 @@ void CreateDeviceTensorForValueNode(const KernelWithIndex &front_node_with_index
// Create device tensor.
std::string output_format = AnfAlgo::GetOutputFormat(backend_node, 0);
device::DeviceAddressPtr address =
device_context->CreateDeviceAddress(nullptr, tensor_size, output_format, output_type_id, ShapeVector());
device::DeviceAddressPtr address = device_context->device_res_manager_->CreateDeviceAddress(
nullptr, tensor_size, output_format, output_type_id, ShapeVector());
MS_EXCEPTION_IF_NULL(address);
MS_LOG(DEBUG) << "Create address for node:" << common::AnfAlgo::GetNodeDebugString(front_node) << " addr:" << address
<< " size:" << tensor_size;

@ -316,7 +316,7 @@ void CreateDeviceTensorForFrontNode(const KernelWithIndex &front_node_with_index
size_t size = 0;
size = AnfAlgo::GetOutputTensorMemSize(node, front_node_with_index.second);
device::DeviceAddressPtr address =
device_context->CreateDeviceAddress(nullptr, size, kOpFormat_DEFAULT, type_id, ShapeVector());
device_context->device_res_manager_->CreateDeviceAddress(nullptr, size, kOpFormat_DEFAULT, type_id, ShapeVector());
MS_EXCEPTION_IF_NULL(address);
MS_LOG(INFO) << "Create address for node that has no corresponding backend node:"
<< common::AnfAlgo::GetNodeDebugString(node) << " addr:" << address << " size:" << size
@ -312,9 +312,9 @@ void ControlNodeScheduler::BuildDataSourceActorForControlNode(const GraphCompile
// Create device tensor.
const auto &device_address = AnfAlgo::GetMutableOutputAddr(backend_node, 0, false);
MS_EXCEPTION_IF_NULL(device_address);
auto new_address =
device_context->CreateDeviceAddress(nullptr, device_address->GetSize(), device_address->format(),
device_address->type_id(), device_address->host_shape());
auto new_address = device_context->device_res_manager_->CreateDeviceAddress(
nullptr, device_address->GetSize(), device_address->format(), device_address->type_id(),
device_address->host_shape());
MS_EXCEPTION_IF_NULL(new_address);
MS_LOG(INFO) << "Create new address for node that has no corresponding backend node:"
<< common::AnfAlgo::GetNodeDebugString(parameter.first) << " addr:" << new_address
@ -102,9 +102,9 @@ void CreateParameterDeviceAddress(const DeviceContext *device_context, const Ker
}

size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(item, index);
auto device_address =
device_context->CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id,
trans::GetRuntimePaddingShape(item, index));
auto device_address = device_context->device_res_manager_->CreateDeviceAddress(
nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id,
trans::GetRuntimePaddingShape(item, index));
device_address->set_from_persistent_mem(item->isa<Parameter>());
MS_LOG(DEBUG) << "Create addr for node:" << common::AnfAlgo::GetNodeDebugString(item)
<< " addr:" << device_address;

@ -144,7 +144,7 @@ void CreateDeviceAddressForTensorValue(const DeviceContext *device_context, cons
}
std::string output_format = AnfAlgo::GetOutputFormat(value_node, output_idx);

device::DeviceAddressPtr address = device_context->CreateDeviceAddress(
device::DeviceAddressPtr address = device_context->device_res_manager_->CreateDeviceAddress(
nullptr, tensor_size, output_format, output_type_id, trans::GetRuntimePaddingShape(value_node, output_idx));
MS_LOG(DEBUG) << "Create addr for node:" << common::AnfAlgo::GetNodeDebugString(value_node) << " addr:" << address;
MS_EXCEPTION_IF_NULL(address);

@ -169,8 +169,8 @@ void CreateValueNodeDeviceAddress(const DeviceContext *device_context, const Ker
} else if (node_value->isa<StringImm>()) {
auto value = GetValue<std::string>(node_value);
size_t tensor_size = value.size();
auto address =
device_context->CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeUInt8, ShapeVector());
auto address = device_context->device_res_manager_->CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT,
kNumberTypeUInt8, ShapeVector());
MS_EXCEPTION_IF_NULL(address);
address->set_from_persistent_mem(true);
MS_LOG(DEBUG) << "Create addr for node:" << common::AnfAlgo::GetNodeDebugString(value_node)

@ -200,8 +200,8 @@ void CreateKernelOutputDeviceAddress(const DeviceContext *device_context, const
auto output_format = AnfAlgo::GetOutputFormat(kernel, i);
auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i);
auto address_size = AnfAlgo::GetOutputTensorMemSize(kernel, i);
auto device_address = device_context->CreateDeviceAddress(nullptr, address_size, output_format, output_type,
trans::GetRuntimePaddingShape(kernel, i));
auto device_address = device_context->device_res_manager_->CreateDeviceAddress(
nullptr, address_size, output_format, output_type, trans::GetRuntimePaddingShape(kernel, i));
if (is_gradient_out) {
device_address->set_from_persistent_mem(true);
}

@ -228,8 +228,8 @@ void CreateKernelWorkspaceDeviceAddress(const DeviceContext *device_context, con
if (AnfAlgo::WorkspaceAddrExist(kernel, i)) {
break;
}
auto device_address =
device_context->CreateDeviceAddress(nullptr, workspace_sizes[i], "", kTypeUnknown, ShapeVector());
auto device_address = device_context->device_res_manager_->CreateDeviceAddress(nullptr, workspace_sizes[i], "",
kTypeUnknown, ShapeVector());
MS_LOG(DEBUG) << "Create addr for node:" << common::AnfAlgo::GetNodeDebugString(kernel)
<< " addr:" << device_address;
AnfAlgo::SetWorkspaceAddr(device_address, i, kernel.get());
@ -528,7 +528,13 @@ GraphId GraphCompiler::CompileGraph(const GraphSegmentPtr &segment, const AnfNod
SetGraphDependency(graph, segment);

// Unify the MindIR, must be before of the graph optimization.
device_context->UnifyMindIR(graph);
auto deprecated_kernel_executor =
dynamic_cast<device::DeprecatedKernelExecutor *>(device_context->kernel_executor_.get());
if (deprecated_kernel_executor != nullptr) {
deprecated_kernel_executor->UnifyMindIR(graph);
} else {
opt::CommonUnifyMindIR(graph);
}

// The graph common optimization.
graph->UpdateGraphAquireGilAttr();

@ -612,7 +618,11 @@ GraphId GraphCompiler::CompileWholeGraphForGraphRunMode(const FuncGraphPtr &func
root_graph->set_is_loop_count_sink(true);

// Unify the MindIR, must be before of the graph optimization.
device_context->UnifyMindIR(root_graph);
auto deprecated_kernel_executor =
dynamic_cast<device::DeprecatedKernelExecutor *>(device_context->kernel_executor_.get());
if (deprecated_kernel_executor != nullptr) {
deprecated_kernel_executor->UnifyMindIR(root_graph);
}

// The graph common optimization.
opt::BackendCommonOptimization(root_graph);

@ -654,11 +664,11 @@ GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph, const Devic
#endif

// Execute optimization pass.
device_context->OptimizeGraph(graph);
device_context->kernel_executor_->OptimizeGraph(graph);

// Generate 'KernelMod' for all kernels and set 'KernelMod' into kernel,
// 'KernelMod' is real executive object of kernel.
device_context->CreateKernel(graph->execution_order());
device_context->kernel_executor_->CreateKernel(graph->execution_order());

// Read the output and input ref map and set to the kernel graph.
AddOutInRefToGraph(graph);

@ -670,7 +680,7 @@ GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph, const Devic
#endif

// Adjust kernel graph before run graph.
device_context->PreprocessBeforeRun(graph);
device_context->kernel_executor_->PreprocessBeforeRun(graph);

// Create device address for all anf nodes of graph.
CreateDeviceAddress(graph, device_context, false);

@ -724,10 +734,16 @@ GraphId GraphCompiler::CompileGraph(const session::OpRunInfo &op_run_info, bool
graph->set_run_mode(device::RunMode::kKernelMode);
graph->set_is_from_single_op(true);
// session_ is SessionBasic, AscendUnifyMindIR has not been executed.
device_context->UnifyMindIR(graph);
auto deprecated_kernel_executor =
dynamic_cast<device::DeprecatedKernelExecutor *>(device_context->kernel_executor_.get());
if (deprecated_kernel_executor != nullptr) {
deprecated_kernel_executor->UnifyMindIR(graph);
} else {
opt::CommonUnifyMindIR(graph);
}

// Select kernel and optimize
device_context->OptimizeGraph(graph);
device_context->kernel_executor_->OptimizeGraph(graph);

UpdateRefInfoBeforeCreateKernel(op_run_info, graph);

@ -772,10 +788,10 @@ void GraphCompiler::BuildSingleOpGraphs(const std::vector<KernelGraphPtr> &graph
std::copy(nodes.begin(), nodes.end(), std::back_inserter(node_to_build));
}
// Kernel build
device_context->CreateKernel(node_to_build);
device_context->kernel_executor_->CreateKernel(node_to_build);

for (const auto &graph : graphs) {
device_context->PreprocessBeforeRun(graph);
device_context->kernel_executor_->PreprocessBeforeRun(graph);
CreateKernelWorkspaceDeviceAddress(device_context, graph);
// Need to execute after PreprocessBeforeRunSingleOpGraph
runtime::OpRuntimeInfo::CacheGraphOpRuntimeInfo(graph);
@ -2099,9 +2099,9 @@ void GraphScheduler::PersistDeviceTensor(const GraphCompilerInfo &graph_compiler
if (DeviceTensorStore::GetInstance().Fetch(front_node.get(), device_context->GetDeviceType()) == nullptr) {
MS_LOG(WARNING) << "Fetch no device tensor store by:" << front_node->fullname_with_scope()
<< ", type:" << device_context->GetDeviceType();
auto other_type_device_tensor =
device_context->CreateDeviceAddress(nullptr, device_tensor->GetSize(), device_tensor->format(),
device_tensor->type_id(), device_tensor->host_shape());
auto other_type_device_tensor = device_context->device_res_manager_->CreateDeviceAddress(
nullptr, device_tensor->GetSize(), device_tensor->format(), device_tensor->type_id(),
device_tensor->host_shape());
other_type_device_tensor->SetNodeIndex(input_node, 0);
other_type_device_tensor->set_from_persistent_mem(input_node->isa<Parameter>());
SchedulerHelper::AddDeviceTensorStore(front_node.get(), other_type_device_tensor);

@ -2143,9 +2143,9 @@ void GraphScheduler::PersistDeviceTensorForRootGraphControlNode(const GraphCompi
auto sub_device_tensor = AnfAlgo::GetMutableOutputAddr(backend_node, 0, false);
MS_EXCEPTION_IF_NULL(sub_device_tensor);

auto new_device_tensor =
device_context->CreateDeviceAddress(nullptr, sub_device_tensor->GetSize(), sub_device_tensor->format(),
sub_device_tensor->type_id(), sub_device_tensor->host_shape());
auto new_device_tensor = device_context->device_res_manager_->CreateDeviceAddress(
nullptr, sub_device_tensor->GetSize(), sub_device_tensor->format(), sub_device_tensor->type_id(),
sub_device_tensor->host_shape());
MS_EXCEPTION_IF_NULL(new_device_tensor);
new_device_tensor->SetNodeIndex(backend_node, 0);
new_device_tensor->set_is_ptr_persisted(sub_device_tensor->is_ptr_persisted());
@ -18,7 +18,7 @@

namespace mindspore {
namespace device {
bool DeviceContext::CreateStream(size_t *stream_id) {
bool DeviceResManager::CreateStream(size_t *stream_id) {
MS_EXCEPTION_IF_NULL(stream_id);
std::lock_guard<std::mutex> locker(stream_mutex_);
void *stream = nullptr;

@ -38,7 +38,7 @@ bool DeviceContext::CreateStream(size_t *stream_id) {
return true;
}

bool DeviceContext::DestroyAllStreams() {
bool DeviceResManager::DestroyAllStreams() {
for (auto &item : stream_ids_) {
if (item.second != nullptr) {
if (!DestroyStream(item.second)) {

@ -53,7 +53,7 @@ bool DeviceContext::DestroyAllStreams() {
return true;
}

void *DeviceContext::GetStream(size_t stream_id) const {
void *DeviceResManager::GetStream(size_t stream_id) const {
auto iter = stream_ids_.find(stream_id);
if (iter == stream_ids_.end()) {
MS_LOG(ERROR) << "Can not find stream for stream id[" << stream_id << "]";

@ -63,12 +63,12 @@ void *DeviceContext::GetStream(size_t stream_id) const {
return iter->second;
}

bool DeviceContext::AllocateMemory(DeviceAddress *const &address, size_t size) const {
bool DeviceResManager::AllocateMemory(DeviceAddress *const &address, size_t size) const {
MS_EXCEPTION_IF_NULL(address);
auto device_name_in_address = GetDeviceNameByType(static_cast<const DeviceType>(address->GetDeviceType()));
if (device_name_in_address != device_context_key_.device_name_) {
if (device_name_in_address != device_context_->device_context_key().device_name_) {
MS_LOG(EXCEPTION) << "The device address type is wrong: type name in address:" << device_name_in_address
<< ", type name in context:" << device_context_key_.device_name_;
<< ", type name in context:" << device_context_->device_context_key().device_name_;
}

if (address->GetPtr() != nullptr) {

@ -85,16 +85,16 @@ bool DeviceContext::AllocateMemory(DeviceAddress *const &address, size_t size) c
return true;
}

void DeviceContext::FreeMemory(DeviceAddress *const &address) const {
void DeviceResManager::FreeMemory(DeviceAddress *const &address) const {
MS_EXCEPTION_IF_NULL(address);
if (address->GetPtr() == nullptr) {
MS_LOG(EXCEPTION) << "Device ptr is null in device address to release!";
}

auto device_name_in_address = GetDeviceNameByType(static_cast<const DeviceType>(address->GetDeviceType()));
if (device_name_in_address != device_context_key_.device_name_) {
if (device_name_in_address != device_context_->device_context_key().device_name_) {
MS_LOG(EXCEPTION) << "The device address type is wrong: type name in address:" << device_name_in_address
<< ", type name in context:" << device_context_key_.device_name_;
<< ", type name in context:" << device_context_->device_context_key().device_name_;
}

if (!address->from_mem_pool()) {
@ -49,11 +49,14 @@ struct DeviceContextKey {
std::string ToString() const { return device_name_ + "_" + std::to_string(device_id_); }
};

class DeviceResManager;
class GraphExecutor;
class KernelExecutor;

// DeviceContext is unified interface of interaction with device.
class DeviceContext {
public:
explicit DeviceContext(const DeviceContextKey &device_context_key)
: device_context_key_(device_context_key), collective_comm_lib_(nullptr) {}
explicit DeviceContext(const DeviceContextKey &device_context_key) : device_context_key_(device_context_key) {}
virtual ~DeviceContext() = default;

// Initialize the device context.

@ -67,49 +70,54 @@ class DeviceContext {
virtual bool PartitionGraph(const FuncGraphPtr &func_graph) const { return false; }

// Analysis the function graph and select the appropriate run mode for the graph
virtual RunMode GetRunMode(const FuncGraphPtr &func_graph) const { return RunMode::kKernelMode; }
virtual RunMode GetRunMode(const FuncGraphPtr &func_graph) const = 0;

// Get device_context_key_ to obtain device name and device id.
const DeviceContextKey &device_context_key() const { return device_context_key_; }

// Get device address type according different device type, such GPU, Ascend.
DeviceType GetDeviceType() const { return GetDeviceTypeByName(device_context_key_.device_name_); }

DeviceContextKey device_context_key_;
std::unique_ptr<DeviceResManager> device_res_manager_;
std::unique_ptr<GraphExecutor> graph_executor_;
std::unique_ptr<KernelExecutor> kernel_executor_;
};
using DeviceContextPtr = std::shared_ptr<DeviceContext>;

class DeviceResManager {
public:
DeviceResManager() : collective_comm_lib_(nullptr) {}
virtual ~DeviceResManager() = default;

// Initialize the device resource manager.
virtual void Initialize() {}

// Destroy device resource manager and release device resource.
virtual void Destroy() {}

// Bind device to current thread to gain device control privileges
virtual bool BindDeviceToCurrentThread() const { return true; }

// Relevant function to allocate and free device memory of raw ptr.
virtual void *AllocateMemory(size_t size) const = 0;
virtual void FreeMemory(void *ptr) const = 0;

// Relevant function to allocate and free device memory of DeviceAddress.
bool AllocateMemory(DeviceAddress *const &address, size_t size) const;
void FreeMemory(DeviceAddress *const &address) const;

// Allocate continuous device memory according to size list.
// Communication operators may need continuous memory for input and output
// to optimize the communication performance.
virtual std::vector<void *> AllocateContinuousMemory(const std::vector<size_t> &size_list) const {
std::vector<void *> ptr_list;
return ptr_list;
MS_LOG(EXCEPTION) << "Unimplemented interface.";
}

// Create concrete device address according different device type.
virtual DeviceAddressPtr CreateDeviceAddress(void *const device_ptr, size_t device_size, const string &format,
TypeId type_id, const ShapeVector &shape) const = 0;

// Unify the MindIR, the default behavior uses the common unified MindIR.
virtual void UnifyMindIR(const KernelGraphPtr &graph) const { opt::CommonUnifyMindIR(graph); }

// Optimize the kernel graph for graph mode.
virtual void OptimizeGraph(const FuncGraphPtr &graph) const {}

// Generate 'KernelMod' for all kernels and set 'KernelMod' into kernel,
// 'KernelMod' is real executive object of kernel.
virtual void CreateKernel(const std::vector<CNodePtr> &nodes) const = 0;

// Adjust kernel graph before run graph.
virtual void PreprocessBeforeRun(const FuncGraphPtr &graph) const {}

// Launch graph, device such as Ascend support the whole graph sink to the device executing.
virtual bool LaunchGraph(const KernelGraphPtr &graph) const { return true; }

// Launch a kernel via 'KernelMod' of the kernel.
virtual bool LaunchKernel(const CNodePtr &kernel, const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace, const std::vector<AddressPtr> &outputs) const {
return true;
}

// Bind device to current thread to gain device control privileges
virtual bool BindDeviceToCurrentThread() const { return true; }

// Create a stream with assigning a stream id, the assigned stream id will be written to the variable '*stream_id';
bool CreateStream(size_t *stream_id);

@ -126,15 +134,6 @@ class DeviceContext {
// Get physical stream based on logical stream id.
void *GetStream(size_t stream_id) const;

// Get rank id for distributed training.
// It is deprecated and will be removed in a future version
virtual uint32_t GetRankID() const { return 0; }

// Create and initialize bucket for every allreduce operator. Bucket is used in PyNative distributed training mode,
// one bucket handles all resource to launch and sync allreduce operator.
// It is deprecated and will be removed in a future version
virtual std::shared_ptr<Bucket> CreateBucket(uint32_t bucket_id, uint32_t bucket_size) const { return nullptr; }

// Dynamically load collective communication library.
// Currently, four types are supported: OpenMPI and self developed framework for CPU. NCCL for GPU. HCCL for Ascend.
virtual bool LoadCollectiveCommLib() { return true; }

@ -142,25 +141,12 @@ class DeviceContext {
// Return collective communication object for caller to access
CollectiveCommunicationLib *collective_comm_lib() const { return collective_comm_lib_; }

// Get device_context_key_ to obtain device name and device id.
const DeviceContextKey &device_context_key() const { return device_context_key_; }

// Get device address type according different device type, such GPU, Ascend.
DeviceType GetDeviceType() const { return GetDeviceTypeByName(device_context_key_.device_name_); }

// Relevant function to allocate and free device memory of DeviceAddress.
bool AllocateMemory(DeviceAddress *const &address, size_t size) const;
void FreeMemory(DeviceAddress *const &address) const;

protected:
// Create a stream on the device of this device context.
virtual bool CreateStream(void **stream) const { return true; }

// Destroy a stream on the device of this device context.
virtual bool DestroyStream(void *stream) const { return true; }

DeviceContextKey device_context_key_;

// Record stream ids to stream address, key: stream id, value: address of stream.
std::map<size_t, void *> stream_ids_;

@ -172,8 +158,124 @@ class DeviceContext {

// The collective communication library.
CollectiveCommunicationLib *collective_comm_lib_;

DeviceContext *device_context_;

private:
template <class... Args>
friend class DeviceInterface;

void SetDeviceContext(DeviceContext *device_context) { device_context_ = device_context; }
};

class GraphExecutor {
public:
virtual ~GraphExecutor() = default;
virtual bool CompileGraph(const FuncGraphPtr &graph, const std::map<string, string> &compile_options) { return true; }
virtual bool RunGraph(const FuncGraphPtr &graph, const std::vector<tensor::Tensor> &inputs,
std::vector<tensor::Tensor> *outputs, const std::map<string, string> &compile_options) {
MS_LOG(EXCEPTION) << "Unimplemented interface.";
}

protected:
DeviceContext *device_context_;

private:
template <class... Args>
friend class DeviceInterface;

void SetDeviceContext(DeviceContext *device_context) { device_context_ = device_context; }
};

class KernelExecutor {
public:
virtual ~KernelExecutor() = default;

// Optimize the kernel graph for graph mode.
virtual void OptimizeGraph(const FuncGraphPtr &graph) const {}

// Generate 'KernelMod' for all kernels and set 'KernelMod' into kernel,
// 'KernelMod' is real executive object of kernel.
virtual void CreateKernel(const std::vector<CNodePtr> &nodes) const {}

// Adjust kernel graph before run graph.
virtual void PreprocessBeforeRun(const FuncGraphPtr &graph) const {}

// Launch a kernel via 'KernelMod' of the kernel.
virtual bool LaunchKernel(const CNodePtr &kernel, const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace, const std::vector<AddressPtr> &outputs) const {
MS_LOG(EXCEPTION) << "Unimplemented interface.";
}

protected:
DeviceContext *device_context_;

private:
template <class... Args>
friend class DeviceInterface;

void SetDeviceContext(DeviceContext *device_context) { device_context_ = device_context; }
};

class DeprecatedKernelExecutor : public KernelExecutor {
public:
// Unify the MindIR, the default behavior uses the common unified MindIR.
// It is deprecated and will be removed in a future version
virtual void UnifyMindIR(const KernelGraphPtr &graph) const { opt::CommonUnifyMindIR(graph); }

// Get rank id for distributed training.
// It is deprecated and will be removed in a future version
virtual uint32_t GetRankID() const { return 0; }

// Create and initialize bucket for every allreduce operator. Bucket is used in PyNative distributed training mode,
// one bucket handles all resource to launch and sync allreduce operator.
// It is deprecated and will be removed in a future version
virtual std::shared_ptr<Bucket> CreateBucket(uint32_t bucket_id, uint32_t bucket_size) const { return nullptr; }
};

template <class... Args>
class DeviceInterface : public DeviceContext {};

template <>
class DeviceInterface<> : public DeviceContext {
public:
explicit DeviceInterface(const DeviceContextKey &key) : DeviceContext(key) {}

protected:
void CheckUnset(void *ptr, const std::string &error_msg) {
if (ptr != nullptr) {
MS_LOG(EXCEPTION) << error_msg;
}
}
};

template <class T, class... Args>
class DeviceInterface<T, Args...> : public DeviceInterface<Args...> {
public:
explicit DeviceInterface(const DeviceContextKey &key) : DeviceInterface<Args...>(key) {
if constexpr (std::is_base_of_v<DeviceResManager, T>) {
DeviceInterface::CheckUnset(reinterpret_cast<void *>(DeviceContext::device_res_manager_.get()),
"DeviceResManager has been registered!");
DeviceContext::device_res_manager_ = std::make_unique<T>();
DeviceContext::device_res_manager_->SetDeviceContext(this);
} else if constexpr (std::is_base_of_v<GraphExecutor, T>) {
DeviceInterface::CheckUnset(reinterpret_cast<void *>(DeviceContext::graph_executor_.get()),
"GraphExecutor has been registered!");
DeviceContext::graph_executor_ = std::make_unique<T>();
DeviceContext::graph_executor_->SetDeviceContext(this);
} else if constexpr (std::is_base_of_v<KernelExecutor, T>) {
DeviceInterface::CheckUnset(reinterpret_cast<void *>(DeviceContext::kernel_executor_.get()),
"KernelExecutor has been registered!");
DeviceContext::kernel_executor_ = std::make_unique<T>();
DeviceContext::kernel_executor_->SetDeviceContext(this);
}
}

private:
template <typename = std::enable_if_t<std::is_base_of_v<DeviceResManager, T> || std::is_base_of_v<GraphExecutor, T> ||
std::is_base_of_v<KernelExecutor, T>>>
void Assert() {}
};
using DeviceContextPtr = std::shared_ptr<DeviceContext>;
} // namespace device
} // namespace mindspore
#endif // MINDSPORE_CCSRC_RUNTIME_HARDWARE_DEVICE_CONTEXT_H_
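For orientation only, a minimal sketch (not part of this commit) of how a backend could wire its components through the variadic DeviceInterface declared above; the MyCustom* names are hypothetical and the bodies are stubs, mirroring the TestDeviceContext pattern in the unit test further down.

class MyCustomResManager : public DeviceResManager {
 public:
  // Stubs: a real plugin would call its own allocator here.
  void *AllocateMemory(size_t size) const override { return nullptr; }
  void FreeMemory(void *ptr) const override {}
  DeviceAddressPtr CreateDeviceAddress(void *const device_ptr, size_t device_size, const string &format,
                                       TypeId type_id, const ShapeVector &shape) const override {
    return nullptr;  // a real plugin returns its own DeviceAddress subclass
  }
};

class MyCustomKernelExecutor : public KernelExecutor {};  // inherits the default (no-op/throwing) hooks

class MyCustomDeviceContext : public DeviceInterface<MyCustomKernelExecutor, MyCustomResManager> {
 public:
  explicit MyCustomDeviceContext(const DeviceContextKey &key) : DeviceInterface(key) {}
  RunMode GetRunMode(const FuncGraphPtr &func_graph) const override { return RunMode::kKernelMode; }
};

// Callers then go through the members registered by DeviceInterface, e.g.
// device_context->device_res_manager_->AllocateMemory(...) or
// device_context->kernel_executor_->LaunchKernel(...), as the call-site changes in this commit show.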
@ -33,7 +33,7 @@ void DeviceContextManager::ClearDeviceContexts() {
for (auto &iter : device_contexts_) {
MS_LOG(INFO) << "Release device " << iter.first;
MS_EXCEPTION_IF_NULL(iter.second);
(void)iter.second->DestroyAllStreams();
(void)iter.second->device_res_manager_->DestroyAllStreams();
iter.second->Destroy();
}
device_contexts_.clear();

@ -77,7 +77,7 @@ void DeviceContextManager::WaitTaskFinishOnDevice() const {
for (const auto &item : device_contexts_) {
auto device_context = item.second;
try {
if (device_context != nullptr && !device_context->SyncStream()) {
if (device_context != nullptr && !device_context->device_res_manager_->SyncStream()) {
MS_LOG(ERROR) << "SyncStream failed";
return;
}
@ -141,7 +141,7 @@ void CopyTensorDataToDevice(const tensor::TensorPtr &tensor, const AnfNodePtr &n
|
|||
auto device_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
|
||||
MS_EXCEPTION_IF_NULL(device_address);
|
||||
if ((device_address->GetPtr() == nullptr) &&
|
||||
(!device_context->AllocateMemory(device_address.get(), device_address->GetSize()))) {
|
||||
(!device_context->device_res_manager_->AllocateMemory(device_address.get(), device_address->GetSize()))) {
|
||||
MS_LOG(EXCEPTION) << "Allocate memory failed";
|
||||
}
|
||||
// Copy data from host tensor to device.
|
||||
|
@ -186,7 +186,7 @@ void CopyValueNodeStringToDevice(const ValueNodePtr &node, const device::DeviceC
|
|||
return;
|
||||
}
|
||||
|
||||
if (!device_context->AllocateMemory(node_address.get(), node_address->GetSize())) {
|
||||
if (!device_context->device_res_manager_->AllocateMemory(node_address.get(), node_address->GetSize())) {
|
||||
MS_LOG(EXCEPTION) << "Allocate memory failed";
|
||||
}
|
||||
|
||||
|
@ -266,7 +266,7 @@ bool MallocForKernelInput(const std::shared_ptr<OpRuntimeInfo> &runtime_info,
|
|||
auto input_address = runtime_info->GetInputDeviceAddress(i);
|
||||
MS_EXCEPTION_IF_NULL(input_address);
|
||||
if (input_address->GetPtr() == nullptr &&
|
||||
!device_context->AllocateMemory(input_address.get(), input_address->GetSize())) {
|
||||
!device_context->device_res_manager_->AllocateMemory(input_address.get(), input_address->GetSize())) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
@ -306,7 +306,7 @@ bool MallocForKernelOutput(const std::shared_ptr<OpRuntimeInfo> &runtime_info, c
|
|||
}
|
||||
}
|
||||
if (device_address->GetPtr() == nullptr &&
|
||||
!device_context->AllocateMemory(device_address.get(), device_address->GetSize())) {
|
||||
!device_context->device_res_manager_->AllocateMemory(device_address.get(), device_address->GetSize())) {
|
||||
MS_LOG(ERROR) << "Allocate output memory failed, node:" << node->fullname_with_scope();
|
||||
return false;
|
||||
}
|
||||
|
@ -341,8 +341,8 @@ kernel::AddressPtrList CreateKernelWorkspaceAddress(const std::shared_ptr<OpRunt
|
|||
// Resize of workspaces, because of the dynamic size of workspace.
|
||||
if (workspace_size < workspace_sizes.size()) {
|
||||
for (size_t i = workspace_size; i < workspace_sizes.size(); ++i) {
|
||||
auto device_address =
|
||||
device_context->CreateDeviceAddress(nullptr, workspace_sizes[i], "", kTypeUnknown, ShapeVector());
|
||||
auto device_address = device_context->device_res_manager_->CreateDeviceAddress(nullptr, workspace_sizes[i], "",
|
||||
kTypeUnknown, ShapeVector());
|
||||
MS_LOG(DEBUG) << "Create addr for node:" << common::AnfAlgo::GetNodeDebugString(kernel)
|
||||
<< " addr:" << device_address;
|
||||
AnfAlgo::SetWorkspaceAddr(device_address, i, kernel.get()); // set to kernel_info
|
||||
|
@ -364,7 +364,7 @@ kernel::AddressPtrList CreateKernelWorkspaceAddress(const std::shared_ptr<OpRunt
|
|||
auto device_address = runtime_info->GetWorkspaceDeviceAddress(i);
|
||||
MS_EXCEPTION_IF_NULL(device_address);
|
||||
if (device_address->GetPtr() == nullptr &&
|
||||
!device_context->AllocateMemory(device_address.get(), device_address->GetSize())) {
|
||||
!device_context->device_res_manager_->AllocateMemory(device_address.get(), device_address->GetSize())) {
|
||||
MS_LOG(EXCEPTION) << "Allocate workspace memory failed";
|
||||
}
|
||||
workspaces.emplace_back(
|
||||
|
@ -376,7 +376,7 @@ kernel::AddressPtrList CreateKernelWorkspaceAddress(const std::shared_ptr<OpRunt
|
|||
auto device_address = add_workspaces[i];
|
||||
MS_EXCEPTION_IF_NULL(device_address);
|
||||
if (device_address->GetPtr() == nullptr &&
|
||||
!device_context->AllocateMemory(device_address.get(), device_address->GetSize())) {
|
||||
!device_context->device_res_manager_->AllocateMemory(device_address.get(), device_address->GetSize())) {
|
||||
MS_LOG(EXCEPTION) << "Allocate workspace memory failed";
|
||||
}
|
||||
workspaces.emplace_back(
|
||||
|
@@ -474,7 +474,7 @@ void LaunchKernels(const KernelGraphPtr &graph, const device::DeviceContext *dev
      MS_LOG(EXCEPTION) << "Malloc for kernel output failed, Memory isn't enough, node:" << node->fullname_with_scope();
    }
    auto outputs = CreateKernelOutputAddress(runtime_info);
    if (!device_context->LaunchKernel(node, inputs, workspaces, outputs)) {
    if (!device_context->kernel_executor_->LaunchKernel(node, inputs, workspaces, outputs)) {
      MS_LOG(EXCEPTION) << "Launch kernel failed, name:" << node->fullname_with_scope();
    }
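Kernel launch, in turn, is delegated to the kernel_executor_ member rather than to DeviceContext itself. A minimal sketch of the launch-and-check idiom shown above (LaunchOneKernel is a hypothetical wrapper; the address lists are assumed to be prepared as in LaunchKernels):

// Sketch only: launching a node through the context's kernel executor.
void LaunchOneKernel(const device::DeviceContext *device_context, const CNodePtr &node,
                     const kernel::AddressPtrList &inputs, const kernel::AddressPtrList &workspaces,
                     const kernel::AddressPtrList &outputs) {
  MS_EXCEPTION_IF_NULL(device_context);
  if (!device_context->kernel_executor_->LaunchKernel(node, inputs, workspaces, outputs)) {
    MS_LOG(EXCEPTION) << "Launch kernel failed, name:" << node->fullname_with_scope();
  }
}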
@@ -159,6 +159,10 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
    "../../../mindspore/ccsrc/plugin/device/ascend/hal/device/ascend_memory_pool.cc"
    "../../../mindspore/ccsrc/plugin/device/ascend/hal/device/lic_manager.cc"
    "../../../mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_device_context.cc"
    "../../../mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_device_res_manager.cc"
    "../../../mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_kernel_executor.cc"
    "../../../mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_graph_executor.cc"
    "../../../mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_utils.cc"
    "../../../mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_graph_optimization.cc"
    "../../../mindspore/ccsrc/plugin/device/cpu/hal/hardware/ms_collective_topo.cc"
    "../../../mindspore/ccsrc/plugin/device/cpu/kernel/cpu_kernel.cc"
@@ -60,11 +60,11 @@ class TestDeviceAddress : public DeviceAddress {
  virtual void ClearDeviceMemory() {}
};

class TestDeviceContext : public DeviceContext {
class TestDeviceResManager : public device::DeviceResManager {
 public:
  explicit TestDeviceContext(const DeviceContextKey &device_context_key) : DeviceContext(device_context_key) {}
  ~TestDeviceContext() override = default;
  virtual void Initialize() {}
  TestDeviceResManager() = default;
  ~TestDeviceResManager() override = default;

  virtual bool AllocateMemory(DeviceAddress *const &address, size_t size) const { return true; }
  virtual void FreeMemory(DeviceAddress *const &address) const {}
  virtual void *AllocateMemory(size_t size) const { return nullptr; }
@@ -73,9 +73,23 @@ class TestDeviceContext : public DeviceContext {
                                               TypeId type_id, const ShapeVector &shape) const {
    return std::make_shared<TestDeviceAddress>(nullptr, 0);
  }
};

class TestKernelExecutor : public device::KernelExecutor {
 public:
  TestKernelExecutor() = default;
  ~TestKernelExecutor() override = default;
};

class TestDeviceContext : public device::DeviceInterface<TestKernelExecutor, TestDeviceResManager> {
 public:
  explicit TestDeviceContext(const DeviceContextKey &device_context_key)
      : DeviceInterface(device_context_key) {}
  ~TestDeviceContext() override = default;

  virtual void Initialize() {}
  virtual DeviceType GetDeviceType() const { return DeviceType::kCPU; }
  virtual void SetOperatorInfo(const KernelGraphPtr &graph) const {}
  virtual void CreateKernel(const std::vector<CNodePtr> &nodes) const {}
  device::RunMode GetRunMode(const FuncGraphPtr &func_graph) const override { return device::RunMode::kKernelMode; }
};

KernelGraphPtr BuildKernelGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &front_node,
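The unit-test changes above illustrate the intended composition: a backend supplies a DeviceResManager and a KernelExecutor and combines them through the DeviceInterface template. A minimal usage sketch, assuming DeviceInterface wires up the two members so they are reachable as device_res_manager_ and kernel_executor_ (the key values are illustrative only):

// Sketch only: instantiating the stub context defined above.
device::DeviceContextKey key{"CPU", 0};
TestDeviceContext device_context(key);
device_context.Initialize();
// Assumes the members are initialized by DeviceInterface / Initialize().
auto addr = device_context.device_res_manager_->CreateDeviceAddress(nullptr, 0, "", kTypeUnknown, ShapeVector());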
@@ -1,22 +1,23 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "common/common_test.h"
#include "abstract/abstract_function.h"
#include "runtime/graph_scheduler/graph_compiler.h"
#include "runtime/hardware/device_context.h"
#include "kernel/kernel.h"

namespace mindspore {
@@ -31,127 +32,142 @@ using DeviceType = device::DeviceType;
using AddressPtr = kernel::AddressPtr;

class TestDeviceAddress : public DeviceAddress {
 public:
  TestDeviceAddress(void *ptr, size_t size) : DeviceAddress(ptr, size) {}
  ~TestDeviceAddress() {}
  virtual bool SyncDeviceToHost(const ShapeVector &shape, size_t size, TypeId type, void *host_ptr) const {
    return true;
  }
  virtual bool SyncHostToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *host_ptr,
                                const std::string &format) const {
    return true;
  }
  virtual void *GetMutablePtr() const { return nullptr; }
  virtual void ClearDeviceMemory() {}
};

class TestKernelMod : public kernel::KernelMod {
 public:
  TestKernelMod() = default;
  ~TestKernelMod() override = default;
  virtual bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                      const std::vector<AddressPtr> &outputs, void *stream_ptr) {
    return true;
  }
};

class TestADeviceContext : public DeviceContext {
 public:
  explicit TestADeviceContext(const DeviceContextKey &device_context_key) : DeviceContext(device_context_key) {}
  ~TestADeviceContext() override = default;
  virtual void Initialize() {}
  virtual bool AllocateMemory(DeviceAddress *const &address, size_t size) const { return true; }
  virtual void FreeMemory(DeviceAddress *const &address) const {}
  virtual void *AllocateMemory(size_t size) const { return nullptr; }
  virtual void FreeMemory(void *const ptr) const {}
  virtual DeviceAddressPtr CreateDeviceAddress(void *const device_ptr, size_t device_size, const string &format,
                                               TypeId type_id, const ShapeVector &shape) const {
    return std::make_shared<TestDeviceAddress>(nullptr, 0);
  }
  virtual DeviceType GetDeviceType() const { return DeviceType::kCPU; }
  virtual void SetOperatorInfo(const KernelGraphPtr &graph) const {}
  virtual void CreateKernel(const std::vector<CNodePtr> &nodes) const {
    for (const auto node : nodes) {
      MS_EXCEPTION_IF_NULL(node);
      if (node->kernel_info() == nullptr) {
        auto kernel_info = std::make_shared<device::KernelInfo>();
        std::shared_ptr<KernelBuildInfoBuilder> builder = std::make_shared<KernelBuildInfoBuilder>();
        kernel_info->set_select_kernel_build_info(builder->Build());
        node->set_kernel_info(kernel_info);
      } else {
        const auto &kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
        if (kernel_info->select_kernel_build_info() == nullptr) {
          std::shared_ptr<KernelBuildInfoBuilder> builder = std::make_shared<KernelBuildInfoBuilder>();
          kernel_info->set_select_kernel_build_info(builder->Build());
        }
      }
      auto kernel_mod_ptr = std::make_shared<TestKernelMod>();
      kernel_mod_ptr->SetInputSizeList({4});
      kernel_mod_ptr->SetOutputSizeList({4});
      kernel_mod_ptr->SetWorkspaceSizeList({4});
      AnfAlgo::SetKernelMod(kernel_mod_ptr, node.get());
    }
  }
class TestADeviceResManager : public device::DeviceResManager {
 public:
  TestADeviceResManager() = default;
  ~TestADeviceResManager() override = default;

  virtual bool AllocateMemory(DeviceAddress *const &address, size_t size) const { return true; }
  virtual void FreeMemory(DeviceAddress *const &address) const {}
  virtual void *AllocateMemory(size_t size) const { return nullptr; }
  virtual void FreeMemory(void *const ptr) const {}
  virtual DeviceAddressPtr CreateDeviceAddress(void *const device_ptr, size_t device_size, const string &format,
                                               TypeId type_id, const ShapeVector &shape) const {
    return std::make_shared<TestDeviceAddress>(nullptr, 0);
  }
};

class TestAKernelExecutor : public device::KernelExecutor {
 public:
  TestAKernelExecutor() = default;
  ~TestAKernelExecutor() override = default;
  virtual void CreateKernel(const std::vector<CNodePtr> &nodes) const {
    for (const auto node : nodes) {
      MS_EXCEPTION_IF_NULL(node);
      if (node->kernel_info() == nullptr) {
        auto kernel_info = std::make_shared<device::KernelInfo>();
        std::shared_ptr<KernelBuildInfoBuilder> builder = std::make_shared<KernelBuildInfoBuilder>();
        kernel_info->set_select_kernel_build_info(builder->Build());
        node->set_kernel_info(kernel_info);
      } else {
        const auto &kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
        if (kernel_info->select_kernel_build_info() == nullptr) {
          std::shared_ptr<KernelBuildInfoBuilder> builder = std::make_shared<KernelBuildInfoBuilder>();
          kernel_info->set_select_kernel_build_info(builder->Build());
        }
      }
      auto kernel_mod_ptr = std::make_shared<TestKernelMod>();
      kernel_mod_ptr->SetInputSizeList({4});
      kernel_mod_ptr->SetOutputSizeList({4});
      kernel_mod_ptr->SetWorkspaceSizeList({4});
      AnfAlgo::SetKernelMod(kernel_mod_ptr, node.get());
    }
  }
};

class TestADeviceContext : public device::DeviceInterface<TestAKernelExecutor, TestADeviceResManager> {
 public:
  explicit TestADeviceContext(const DeviceContextKey &device_context_key)
      : DeviceInterface(device_context_key) {}
  ~TestADeviceContext() override = default;

  virtual void Initialize() {}
  virtual DeviceType GetDeviceType() const { return DeviceType::kCPU; }
  device::RunMode GetRunMode(const FuncGraphPtr &func_graph) const override { return device::RunMode::kKernelMode; }
};

class GraphCompilerTest : public UT::Common {
 public:
  GraphCompilerTest() {}
};

/// Feature: control flow support dynamic shape.
/// Description: Test the parse interface.
/// Expectation: As expected.
TEST_F(GraphCompilerTest, CompileGraph) {
  std::vector<int64_t> shp{2, 2};
  abstract::AbstractTensorPtr abs;

  // Func graph.
  auto func_graph = std::make_shared<FuncGraph>();

  // Parameter.
  auto abstract_x = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
  auto parameter_x = func_graph->add_parameter();
  parameter_x->set_abstract(abstract_x);

  auto abstract_y = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
  auto parameter_y = func_graph->add_parameter();
  parameter_y->set_abstract(abstract_y);
  auto parameters = func_graph->parameters();

  // Add.
  std::vector<AnfNodePtr> add_inputs{NewValueNode(prim::kPrimAdd), parameters[0], parameters[1]};
  auto add_node = func_graph->NewCNode(add_inputs);
  abs = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
  add_node->set_abstract(abs);

  // Reshape.
  std::vector<AnfNodePtr> reshape_inputs{NewValueNode(prim::kPrimReshape), add_node};
  auto reshape_node = func_graph->NewCNode(reshape_inputs);
  abs = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
  reshape_node->set_abstract(abs);

  // sub.
  std::vector<AnfNodePtr> sub_inputs{NewValueNode(prim::kPrimSub), reshape_node, parameters[0]};
  auto sub_node = func_graph->NewCNode(sub_inputs);
  abs = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
  sub_node->set_abstract(abs);

  // Return.
  std::vector<AnfNodePtr> return_inputs{NewValueNode(prim::kPrimReturn), sub_node};
  auto return_node = func_graph->NewCNode(return_inputs);
  func_graph->set_return(return_node);

  std::vector<AnfNodePtr> nodes{add_node, reshape_node, sub_node};
  std::vector<AnfNodePtr> outputs{sub_node};
  auto segment = std::make_shared<GraphSegment>(nodes, false);

  auto compiler = std::make_shared<GraphCompiler>();
  DeviceContextKey device_context_key{"CPU", 0};
  auto device_context = std::make_shared<TestADeviceContext>(device_context_key);
  auto graph_id = compiler->CompileGraph(segment, outputs, device_context.get(), device::RunMode::kKernelMode, false);
  const auto &kernel_graph = compiler->Fetch(graph_id);
  ASSERT_EQ(3, kernel_graph->execution_order().size());
}
} // namespace runtime
} // namespace mindspore