forked from mindspore-Ecosystem/mindspore
!22092 [executor] Add mem scheduler
Merge pull request !22092 from kisnwang/add-mem-extend-cache
commit 6f09891501
@@ -43,11 +43,8 @@ class AscendKernelMod : public KernelMod {
return false;
#endif
}
void SetStream(void *stream) { stream_ = stream; }
void *GetStream() { return stream_; }

protected:
void *stream_{nullptr};
uint32_t block_dim_{1};
uint32_t stream_id_{0};
};
@@ -195,11 +195,14 @@ class KernelMod {
const std::vector<AddressPtr> &GetInputsAddr() { return inputs_addr_; }
const std::vector<AddressPtr> &GetWorkSpacesAddr() { return workspaces_addr_; }
const std::vector<AddressPtr> &GetOutputsAddr() { return outputs_addr_; }
void SetStream(void *stream) { stream_ = stream; }
void *GetStream() const { return stream_; }

protected:
std::string unique_name_;
std::string fullname_;
bool is_monad_{false};
void *stream_{nullptr};

private:
std::vector<AddressPtr> inputs_addr_;
@@ -45,18 +45,6 @@ using DeviceAddress = device::DeviceAddress;
using DeviceAddressPtr = device::DeviceAddressPtr;
using Address = kernel::Address;
using AddressPtr = kernel::AddressPtr;
using KernelWithIndex = std::pair<AnfNodePtr, size_t>;
struct KernelWithIndexCmp {
bool operator()(const KernelWithIndex &key1, const KernelWithIndex &key2) const {
if (key1.first != key2.first) {
return key1.first < key2.first;
}
if (key1.second != key2.second) {
return key1.second < key2.second;
}
return false;
}
};

class OpRuntimeInfo {
public:
@@ -413,13 +413,18 @@ void AscendSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_gra
MS_LOG(EXCEPTION) << "Tensor input:" << inputs.size() << " is not equal graph inputs:" << input_nodes.size()
<< ", input_ctrl_size:" << input_ctrl_size;
}
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
auto enable_mem_scheduler = ms_context->get_param<bool>(MS_CTX_ENABLE_MEM_SCHEDULER);
if (enable_mem_scheduler) {
kernel_graph->SetInputTensors(inputs);
return;
}
for (auto item : tensor_device_addr_map_) {
auto output_tensor = item.first;
output_tensor->set_device_address(item.second);
}
SyncStream();
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
for (size_t i = 0; i < inputs.size(); ++i) {
auto tensor = inputs[i];
MS_EXCEPTION_IF_NULL(tensor);
@@ -655,7 +660,10 @@ void AscendSession::BuildGraphImpl(GraphId graph_id) {
} else {
// alloc memory, including static memory and dynamic memory
MemoryAlloc(graph.get());
AnfAlgo::CacheAddrForGraph(graph);
auto enable_mem_scheduler = ms_context->get_param<bool>(MS_CTX_ENABLE_MEM_SCHEDULER);
if (!enable_mem_scheduler) {
AnfAlgo::CacheAddrForGraph(graph);
}
// generate and load task info to device if it is sink mode
Load(graph);
}
@@ -690,10 +698,13 @@ void AscendSession::CompileChildGraph(const KernelGraphPtr &child_graph) {
// optimize graph
HardwareOptimize(child_graph);
// assign static memory of parameters
auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
MS_EXCEPTION_IF_NULL(runtime_instance);
runtime_instance->AssignStaticMemoryInput(child_graph.get());
runtime_instance->AssignStaticMemoryValueNode(child_graph.get());
auto enable_mem_scheduler = context_ptr->get_param<bool>(MS_CTX_ENABLE_MEM_SCHEDULER);
if (!enable_mem_scheduler) {
auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
MS_EXCEPTION_IF_NULL(runtime_instance);
runtime_instance->AssignStaticMemoryInput(child_graph.get());
runtime_instance->AssignStaticMemoryValueNode(child_graph.get());
}
}

bool AscendSession::IsSupportSummary() { return !device::KernelAdjust::NeedInsertSwitch(); }
@@ -1954,6 +1965,12 @@ void AscendSession::ExecuteAllTaskInQueue() {
void AscendSession::UpdateOutputTensors(const VectorRef *outputs,
const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node,
std::map<DeviceAddressPtr, DeviceAddressPtr> *) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
auto enable_mem_scheduler = context_ptr->get_param<bool>(MS_CTX_ENABLE_MEM_SCHEDULER);
if (enable_mem_scheduler) {
return;
}
MS_EXCEPTION_IF_NULL(outputs);
tensor_device_addr_map_.clear();
for (const auto &item : *outputs) {
@@ -36,6 +36,21 @@
namespace mindspore {
namespace session {
using AnfWithOutIndex = std::pair<AnfNodePtr, size_t>;
using KernelWithIndex = std::pair<AnfNodePtr, size_t>;
struct KernelWithIndexCmp {
bool operator()(const KernelWithIndex &key1, const KernelWithIndex &key2) const {
if (key1.first != key2.first) {
return key1.first < key2.first;
}
if (key1.second != key2.second) {
return key1.second < key2.second;
}
return false;
}
};

using KernelMapTensor = std::map<session::KernelWithIndex, BaseRef, session::KernelWithIndexCmp>;

class KernelGraph : public FuncGraph {
public:
KernelGraph() : graph_id_(0), start_label_(nullptr), end_goto_(nullptr), current_epoch_(0), is_dynamic_shape_(false) {
@@ -260,8 +275,21 @@ class KernelGraph : public FuncGraph {
void SetOptimizerFlag();
void SetInputNodes();
const std::vector<AnfNodePtr> &input_nodes() const { return input_nodes_; }
void SetInputTensors(const std::vector<tensor::TensorPtr> &input_tensors) { input_tensors_ = input_tensors; }
const std::vector<tensor::TensorPtr> &input_tensors() const { return input_tensors_; }

void SetOutputNodeToTensor(const KernelMapTensor &node_to_tensor) { output_node_to_tensor_ = node_to_tensor; }

tensor::TensorPtr GetNodeOutputTensor(const session::KernelWithIndex &output_index) const {
auto iter = output_node_to_tensor_.find(output_index);
if (iter != output_node_to_tensor_.end()) {
return utils::cast<tensor::TensorPtr>(iter->second);
}
return nullptr;
}

bool has_optimizer() const { return has_optimizer_; }
bool IsUpdatedParameter(const ParameterPtr &param) {
bool IsUpdatedParameter(const ParameterPtr &param) const {
if (updated_parameters_.find(param) != updated_parameters_.end()) {
return true;
}
@@ -428,6 +456,8 @@ class KernelGraph : public FuncGraph {
std::map<AnfNodePtr, AnfNodePtr> edge_to_;
std::stack<AnfNodePtr> loop_nodes_;
std::vector<AnfNodePtr> input_nodes_;
std::vector<tensor::TensorPtr> input_tensors_;
KernelMapTensor output_node_to_tensor_;
std::unordered_map<uint32_t, std::weak_ptr<session::KernelGraph>> pre_graphs_;
std::unordered_map<uint32_t, std::weak_ptr<session::KernelGraph>> post_graphs_;
// The send/recv pairs inserted for allreduce, the key is allreduce kernel, the first of pair is send node, the second
@@ -1689,11 +1689,23 @@ void SessionBasic::CreateOutputTensors(const GraphId &graph_id, const std::vecto
MS_LOG(INFO) << "Create node output[" << item->DebugString() << "]";
outputs->emplace_back(CreateNodeOutputTensors(item, kernel_graph, input_tensors, tensor_to_node, &node_to_tensor));
}
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
auto enable_mem_scheduler = ms_context->get_param<bool>(MS_CTX_ENABLE_MEM_SCHEDULER);
if (enable_mem_scheduler) {
kernel_graph->SetOutputNodeToTensor(node_to_tensor);
}
}

void SessionBasic::UpdateOutputTensors(const VectorRef *outputs,
const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node,
std::map<DeviceAddressPtr, DeviceAddressPtr> *) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
auto enable_mem_scheduler = context_ptr->get_param<bool>(MS_CTX_ENABLE_MEM_SCHEDULER);
if (enable_mem_scheduler) {
return;
}
MS_EXCEPTION_IF_NULL(outputs);
for (const auto &item : *outputs) {
if (utils::isa<VectorRefPtr>(item)) {
@@ -88,7 +88,6 @@ struct GraphOutputInfo {
std::vector<tensor::TensorPtr> graph_output_tensors;
};

using KernelMapTensor = std::map<session::KernelWithIndex, BaseRef, session::KernelWithIndexCmp>;
class Executor;

class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
@@ -1267,8 +1267,9 @@ void InitHccl() {
uint32_t device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
#if ENABLE_D
bool task_sink = true;
auto single_op = std::getenv(kGraphOpRun);
if (single_op && std::string(single_op) == "1") {
auto single_op = common::GetEnv(kGraphOpRun);
auto enable_mem_scheduler = common::GetEnv(kEnableMemScheduler);
if (single_op == "1" || enable_mem_scheduler == "1") {
task_sink = false;
}
auto mode = ms_context->get_param<int>(MS_CTX_EXECUTION_MODE);
@@ -1,6 +1,6 @@
file(GLOB_RECURSE DEVICE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "common/*.cc"
"kernel_info.cc" "executor/dynamic_kernel.cc" "executor/executor_callback.cc" "kernel_runtime.cc"
"memory_manager.cc" "kernel_runtime_manager.cc" "convert_tensor_utils.cc"
"memory_manager.cc" "kernel_runtime_manager.cc" "convert_tensor_utils.cc" "memory_scheduler.cc"
"bucket.cc" "launch_kernel.cc" "launch_mul.cc"
)
@@ -173,6 +173,10 @@ void AscendKernelRuntime::ClearGraphModelMap() {

void AscendKernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id) {
SetCurrentContext();
auto mem_scheduler = mem_scheduler_manager_.GetMemScheduler(graph_id);
if (mem_scheduler != nullptr) {
mem_scheduler->Clear();
}
MS_LOG(DEBUG) << "Clear graph:" << graph_id << " data dumper";
#ifndef ENABLE_SECURITY
if (auto dumper_iter = graph_data_dumper_.find(graph_id); dumper_iter != graph_data_dumper_.end()) {
@@ -725,25 +729,6 @@ bool AscendKernelRuntime::Run(session::KernelGraph *const graph, bool is_task_si
return ret;
}

bool AscendKernelRuntime::LaunchKernel(const AnfNodePtr &kernel) {
auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
MS_EXCEPTION_IF_NULL(kernel_mod);
AddressPtrList kernel_inputs = kernel_mod->GetInputsAddr();
AddressPtrList kernel_workspaces = kernel_mod->GetWorkSpacesAddr();
AddressPtrList kernel_outputs = kernel_mod->GetOutputsAddr();
bool ret;
if (pynative_mode_profiling_flag_) {
auto ascend_kernel_mod = dynamic_cast<kernel::AscendKernelMod *>(kernel_mod);
MS_EXCEPTION_IF_NULL(ascend_kernel_mod);
auto stream = ascend_kernel_mod->GetStream();
ret = LaunchKernelWithPynativeProfiling(kernel_mod, kernel->fullname_with_scope(), kernel_inputs, kernel_workspaces,
kernel_outputs, stream);
} else {
ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_);
}
return ret;
}

void AscendKernelRuntime::SetKernelModStream(const std::vector<CNodePtr> &kernels,
std::vector<size_t> *last_stream_nodes) {
auto context_ptr = MsContext::GetInstance();
@@ -48,7 +48,6 @@ class AscendKernelRuntime : public KernelRuntime {
void ProcessBoundaryEvent(const std::vector<CNodePtr> &kernels,
std::vector<std::vector<std::function<void()>>> *kernel_run_events,
const std::vector<size_t> &last_stream_nodes);
bool LaunchKernel(const AnfNodePtr &kernel) override;
bool GenDynamicKernel(const session::KernelGraph *graph) override;
bool RunDynamicKernelAsync(const session::KernelGraph *graph) override;
bool LoadTask(const session::KernelGraph *graph);
@@ -128,6 +128,10 @@ void *AscendMemoryManager::MallocMemFromMemPool(size_t size) {
return AscendMemoryPool::GetInstance().AllocTensorMem(align_size);
}

void AscendMemoryManager::FreeMemFromMemPool(void *device_ptr) {
AscendMemoryPool::GetInstance().FreeTensorMem(device_ptr);
}

uint8_t *AscendMemoryManager::MallocStaticMem(size_t size, bool communication_mem, uint32_t graph_id) {
size_t align_size = 0;
if (communication_mem) {
@@ -209,6 +213,47 @@ uint8_t *AscendMemoryManager::MallocCommunicationMemFromMemPool(size_t size) {
uint8_t *base_ptr = reinterpret_cast<uint8_t *>(AscendMemoryPool::GetInstance().AllocTensorMem(align_size));
return base_ptr + kMemAlignSize;
}

size_t AscendMemoryManager::GetAvailableMemSize() {
auto available_mem_size = AscendMemoryPool::GetInstance().free_mem_size() +
AscendMemoryPool::GetInstance().total_mem_statistics() -
AscendMemoryPool::GetInstance().used_mem_statistics();
return available_mem_size;
}

void AscendMemoryManager::SwapIn(void *host_ptr, void *device_ptr, size_t mem_size, void *stream) {
if (stream == nullptr) {
auto ret_rt_memcpy = rtMemcpy(device_ptr, mem_size, host_ptr, mem_size, RT_MEMCPY_HOST_TO_DEVICE);
if (ret_rt_memcpy != RT_ERROR_NONE) {
MS_EXCEPTION(DeviceProcessError) << "SwapIn rtMemcpy failed.";
}
} else {
auto ret_rt_memcpy = rtMemcpyAsync(device_ptr, mem_size, host_ptr, mem_size, RT_MEMCPY_HOST_TO_DEVICE, stream);
if (ret_rt_memcpy != RT_ERROR_NONE) {
MS_EXCEPTION(DeviceProcessError) << "SwapIn rtMemcpyAsync failed.";
}
if (rtStreamSynchronize(stream) != RT_ERROR_NONE) {
MS_LOG(ERROR) << "Call runtime rtStreamSynchronize error.";
}
}
}

void AscendMemoryManager::SwapOut(void *device_ptr, void *host_ptr, size_t mem_size, void *stream) {
if (stream == nullptr) {
auto ret_rt_memcpy = rtMemcpy(host_ptr, mem_size, device_ptr, mem_size, RT_MEMCPY_DEVICE_TO_HOST);
if (ret_rt_memcpy != RT_ERROR_NONE) {
MS_EXCEPTION(DeviceProcessError) << "SwapOut rtMemcpy failed.";
}
} else {
auto ret_rt_memcpy = rtMemcpyAsync(host_ptr, mem_size, device_ptr, mem_size, RT_MEMCPY_DEVICE_TO_HOST, stream);
if (ret_rt_memcpy != RT_ERROR_NONE) {
MS_EXCEPTION(DeviceProcessError) << "SwapOut rtMemcpyAsync failed.";
}
if (rtStreamSynchronize(stream) != RT_ERROR_NONE) {
MS_LOG(ERROR) << "Call runtime rtStreamSynchronize error.";
}
}
}
} // namespace ascend
} // namespace device
} // namespace mindspore
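Note on the new SwapIn/SwapOut above: when no stream is supplied they fall back to the blocking rtMemcpy, and when a stream is supplied they issue rtMemcpyAsync on that stream and then block on rtStreamSynchronize, so in either case the copy is complete before the call returns.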
@@ -33,6 +33,7 @@ class AscendMemoryManager : public MemoryManager {
void ResetDynamicMemory() override;
void ClearGlobalIdleMem() override;
void *MallocMemFromMemPool(size_t size) override;
void FreeMemFromMemPool(void *device_ptr) override;
uint64_t GetDeviceMemSize();
void MallocSomasDynamicMem(const session::KernelGraph *graph) override;
uint8_t *MallocCommunicationMemFromMemPool(size_t size) override;
@@ -40,6 +41,10 @@ class AscendMemoryManager : public MemoryManager {
return AscendMemoryPool::GetInstance().AllocContinuousTensorMem(total_size, size_list);
}

void SwapIn(void *host_ptr, void *device_ptr, size_t mem_size, void *stream) override;
void SwapOut(void *device_ptr, void *host_ptr, size_t mem_size, void *stream) override;
size_t GetAvailableMemSize() override;

protected:
uint8_t *MallocStaticMem(size_t size, bool communication_mem, uint32_t graph_id = kInvalidGraphId) override;
uint8_t *MallocDynamicMem(size_t size, bool communication_mem) override;
@@ -42,6 +42,9 @@ using mindspore::kernel::AddressPtr;

namespace mindspore {
namespace device {
constexpr float kMaxMemReuseFactor = 0.8;
constexpr float kMinMemReuseFactor = 0.5;
constexpr float kRetryFactor = 0.1;
namespace {
std::vector<AnfNodePtr> GetGraphInputs(const session::KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(graph);
@@ -85,10 +88,16 @@ bool KernelRuntime::NodeOutputDeviceAddressExist(const AnfNodePtr &kernel, size_
void KernelRuntime::AssignMemory(session::KernelGraph *graph) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
MS_EXCEPTION_IF_NULL(mem_manager_);
mem_manager_->ResetDynamicMemory();
AssignStaticMemory(graph);
AssignDynamicMemory(graph);
auto enable_mem_scheduler = context_ptr->get_param<bool>(MS_CTX_ENABLE_MEM_SCHEDULER);
if (enable_mem_scheduler) {
AssignStaticMemoryValueNode(graph);
ResetNodeAddress(graph);
} else {
MS_EXCEPTION_IF_NULL(mem_manager_);
mem_manager_->ResetDynamicMemory();
AssignStaticMemory(graph);
AssignDynamicMemory(graph);
}
UpdateRefNodeOutputMem(graph);
}
@@ -253,6 +262,47 @@ void KernelRuntime::RunOpMallocPre(const session::KernelGraph &graph,
}
}

void KernelRuntime::ResetNodeAddress(session::KernelGraph *kernel_graph) {
MS_EXCEPTION_IF_NULL(kernel_graph);
auto kernels = kernel_graph->execution_order();
for (auto &kernel : kernels) {
auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
MS_EXCEPTION_IF_NULL(kernel_mod);
size_t input_num = AnfAlgo::GetInputTensorNum(kernel);
for (size_t j = 0; j < input_num; ++j) {
auto input_index = AnfAlgo::GetRealInputIndex(kernel, j);
KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(kernel, input_index, true);
auto index = kernel_with_index.second;
auto &input_node = kernel_with_index.first;
if (NodeOutputDeviceAddressExist(input_node, index)) {
continue;
}
TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(input_node, index);
if (output_type_id == kTypeUnknown) {
MS_LOG(WARNING) << "It is not suggested to use a lonely weight parameter as the output of graph";
continue;
}
auto tensor_size = AnfAlgo::GetOutputTensorMemSize(input_node, index);
auto device_address = CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(input_node, index),
output_type_id, {input_node, index});
AnfAlgo::SetOutputAddr(device_address, index, input_node.get());
}

auto output_sizes = kernel_mod->GetOutputSizeList();
for (size_t i = 0; i < output_sizes.size(); ++i) {
auto output_format = AnfAlgo::GetOutputFormat(kernel, i);
auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i);
AnfAlgo::SetOutputAddr(CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type), i,
kernel.get());
}
auto workspace_sizes = kernel_mod->GetWorkspaceSizeList();
for (size_t i = 0; i < workspace_sizes.size(); ++i) {
AnfAlgo::SetWorkspaceAddr(CreateDeviceAddress(nullptr, workspace_sizes[i], kOpFormat_DEFAULT, kNumberTypeFloat32),
i, kernel.get());
}
}
}

void KernelRuntime::RunOpAssignMemory(const std::vector<tensor::TensorPtr> &input_tensors, session::KernelGraph *graph,
const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node) {
MS_EXCEPTION_IF_NULL(graph);
@@ -1125,7 +1175,8 @@ void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod
}
}

void KernelRuntime::GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList *kernel_inputs) {
void KernelRuntime::GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList *kernel_inputs,
const std::shared_ptr<MemScheduler> &mem_scheduler) {
MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(kernel_inputs);
if (cnode->inputs().size() != 2) {
@@ -1144,8 +1195,12 @@ void KernelRuntime::GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList
auto device_address = AnfAlgo::GetOutputAddr(pre_node, index);
kernel::AddressPtr input = std::make_shared<kernel::Address>();
MS_EXCEPTION_IF_NULL(input);
input->addr = device_address->ptr_;
MS_EXCEPTION_IF_NULL(input->addr);
if (mem_scheduler != nullptr) {
GetOrMallocAddress(mem_scheduler, device_address, input);
} else {
input->addr = device_address->ptr_;
MS_EXCEPTION_IF_NULL(input->addr);
}
input->size = device_address->size_;
kernel_inputs->emplace_back(input);
}
@@ -1162,8 +1217,12 @@ void KernelRuntime::GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList
auto device_address = AnfAlgo::GetWorkspaceAddr(pre_node, index);
kernel::AddressPtr workspace = std::make_shared<kernel::Address>();
MS_EXCEPTION_IF_NULL(workspace);
workspace->addr = device_address->ptr_;
MS_EXCEPTION_IF_NULL(workspace->addr);
if (mem_scheduler != nullptr) {
GetOrMallocAddress(mem_scheduler, device_address, workspace);
} else {
workspace->addr = device_address->ptr_;
MS_EXCEPTION_IF_NULL(workspace->addr);
}
workspace->size = device_address->size_;
kernel_inputs->emplace_back(workspace);
}
@@ -1220,36 +1279,197 @@ void KernelRuntime::DebugStreamSync(const CNodePtr &kernel) {
}
}

bool KernelRuntime::LaunchKernel(const AnfNodePtr &kernel) {
void KernelRuntime::GetOrMallocAddress(const std::shared_ptr<MemScheduler> &mem_scheduler,
const DeviceAddress *device_address, const kernel::AddressPtr &kernel_addr) {
if (device_address->ptr_ != nullptr) {
kernel_addr->addr = device_address->ptr_;
} else {
kernel_addr->addr = mem_scheduler->GetOrMalloc(device_address, device_address->size_);
if (mem_scheduler->IsHighPriorityMem(device_address)) {
device_address->ptr_ = kernel_addr->addr;
}
}
}

void KernelRuntime::AssignKernelAddress(const std::shared_ptr<MemScheduler> &mem_scheduler, const AnfNodePtr &kernel,
AddressPtrList *kernel_inputs, AddressPtrList *kernel_workspaces,
AddressPtrList *kernel_outputs) {
MS_EXCEPTION_IF_NULL(kernel);
MS_EXCEPTION_IF_NULL(kernel_inputs);
MS_EXCEPTION_IF_NULL(kernel_workspaces);
MS_EXCEPTION_IF_NULL(kernel_outputs);
auto cnode = kernel->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (AnfAlgo::GetCNodeName(cnode) == kAtomicAddrCleanOpName) {
return GenAddrCleanLaunchArgs(cnode, kernel_inputs, mem_scheduler);
}
auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
MS_EXCEPTION_IF_NULL(kernel_mod);
size_t input_num = AnfAlgo::GetInputTensorNum(kernel);
for (size_t j = 0; j < input_num; ++j) {
auto real_input = AnfAlgo::GetRealInputIndex(kernel, j);
auto kernel_with_index = AnfAlgo::GetPrevNodeOutput(kernel, real_input, true);
auto index = kernel_with_index.second;
auto &input_node = kernel_with_index.first;
auto device_address = AnfAlgo::GetOutputAddr(input_node, index, true);
MS_EXCEPTION_IF_NULL(device_address);
kernel::AddressPtr input = std::make_shared<kernel::Address>();
GetOrMallocAddress(mem_scheduler, device_address, input);
input->size = device_address->size_;
kernel_inputs->emplace_back(input);
}

for (size_t j = 0; j < kernel_mod->GetOutputSizeList().size(); ++j) {
auto device_address = AnfAlgo::GetOutputAddr(kernel, j, true);
kernel::AddressPtr output = std::make_shared<kernel::Address>();
GetOrMallocAddress(mem_scheduler, device_address, output);
output->size = device_address->size_;
kernel_outputs->emplace_back(output);
}

for (size_t i = 0; i < kernel_mod->GetWorkspaceSizeList().size(); ++i) {
auto device_address = AnfAlgo::GetWorkspaceAddr(kernel, i);
kernel::AddressPtr workspace = std::make_shared<kernel::Address>();
GetOrMallocAddress(mem_scheduler, device_address, workspace);
workspace->size = device_address->size_;
kernel_workspaces->emplace_back(workspace);
}
}

void KernelRuntime::SyncNodeOutputTensors(const std::shared_ptr<MemScheduler> &mem_scheduler,
const session::KernelGraph *graph, const AnfNodePtr &kernel, bool mock) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(mem_scheduler);
MS_EXCEPTION_IF_NULL(kernel);
auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
MS_EXCEPTION_IF_NULL(kernel_mod);
for (size_t j = 0; j < kernel_mod->GetOutputSizeList().size(); ++j) {
auto tensor = graph->GetNodeOutputTensor(std::make_pair(kernel, j));
auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, j, true);
if (mock) {
if (graph->IsInternalOutput(kernel, j) && device_address != nullptr) {
mem_scheduler->SetMemPriority(device_address.get(), kMemPriorityHigh);
}
continue;
}
if (tensor != nullptr) {
if (device_address == nullptr) {
tensor->data_sync(false);
tensor->set_device_address(nullptr);
tensor->set_sync_status(kNeedSyncHostToDevice);
continue;
}
SyncStream();
auto origin_ptr = device_address->ptr_;
if (origin_ptr == nullptr) {
device_address->ptr_ = mem_scheduler->GetOrMalloc(device_address.get(), device_address->size_);
}
tensor->set_device_address(device_address);
tensor->data_sync(false);
tensor->set_device_address(nullptr);
if (origin_ptr == nullptr) {
device_address->ptr_ = nullptr;
}
tensor->set_sync_status(kNeedSyncHostToDevice);
}
}
}

void KernelRuntime::InitGraphInputTensors(const std::shared_ptr<MemScheduler> &mem_scheduler,
const session::KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(mem_scheduler);
auto &input_nodes = graph->input_nodes();
auto &input_tensors = graph->input_tensors();
if (input_tensors.size() != input_nodes.size()) {
MS_LOG_EXCEPTION << "Invalid input tensor size:" << input_tensors.size() << " vs node size:" << input_nodes.size();
}
for (size_t i = 0; i < input_tensors.size(); ++i) {
auto tensor = input_tensors[i];
MS_EXCEPTION_IF_NULL(tensor);
auto input_node = input_nodes[i];
if (!input_node->isa<Parameter>()) {
continue;
}
auto input_param = input_node->cast<ParameterPtr>();
if (AnfAlgo::OutputAddrExist(input_node, 0)) {
auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0);
MS_EXCEPTION_IF_NULL(tensor);
MemPriority priority = kMemPriorityHigh;
auto tensor_address = tensor->device_address();
if (tensor_address != nullptr && tensor_address != device_address) {
tensor->data_sync(false);
priority = kMemPriorityLow;
}
mem_scheduler->Init(device_address.get(), tensor->data_c(), tensor->data().nbytes(), priority);
}
}
}

bool KernelRuntime::LaunchKernel(const session::KernelGraph *graph, const AnfNodePtr &kernel,
const std::shared_ptr<MemScheduler> &mem_scheduler, bool mock) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(kernel);
auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
MS_EXCEPTION_IF_NULL(kernel_mod);
AddressPtrList kernel_inputs;
AddressPtrList kernel_workspaces;
AddressPtrList kernel_outputs;
GenLaunchArgs(*kernel_mod, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
bool ret;
if (AnfAlgo::IsCommunicationOp(kernel)) {
if (pynative_mode_profiling_flag_) {
ret = LaunchKernelWithPynativeProfiling(kernel_mod, kernel->fullname_with_scope(), kernel_inputs,
kernel_workspaces, kernel_outputs, communication_stream_);
auto stream = kernel_mod->GetStream();
if (stream == nullptr) {
if (AnfAlgo::IsCommunicationOp(kernel)) {
stream = communication_stream_;
} else {
ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, communication_stream_);
stream = stream_;
}
}
bool ret = true;
if (mem_scheduler != nullptr) {
ret = mem_scheduler->PreCompute(stream);
if (!ret) {
return ret;
}
AssignKernelAddress(mem_scheduler, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
} else if (!kernel_mod->GetInputsAddr().empty() || !kernel_mod->GetOutputsAddr().empty()) {
kernel_inputs = kernel_mod->GetInputsAddr();
kernel_outputs = kernel_mod->GetOutputsAddr();
kernel_workspaces = kernel_mod->GetWorkSpacesAddr();
} else {
GenLaunchArgs(*kernel_mod, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
}
if (!mock) {
if (pynative_mode_profiling_flag_) {
ret = LaunchKernelWithPynativeProfiling(kernel_mod, kernel->fullname_with_scope(), kernel_inputs,
kernel_workspaces, kernel_outputs, stream_);
kernel_workspaces, kernel_outputs, stream);
} else {
ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_);
ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream);
}
}
if (mem_scheduler != nullptr) {
SyncNodeOutputTensors(mem_scheduler, graph, kernel, mock);
ret = mem_scheduler->PostCompute(stream);
if (!ret) {
return ret;
}
}
return ret;
}

bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph) {
const auto &kernels = graph.execution_order();
bool KernelRuntime::LaunchKernelMod(const session::KernelGraph *graph, bool mock) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
std::shared_ptr<MemScheduler> mem_scheduler = nullptr;
auto enable_mem_scheduler = context_ptr->get_param<bool>(MS_CTX_ENABLE_MEM_SCHEDULER);
if (enable_mem_scheduler) {
mem_scheduler = mem_scheduler_manager_.GetOrCreateMemScheduler(graph->graph_id());
MS_EXCEPTION_IF_NULL(mem_scheduler);
mem_scheduler->SetMemHandler(mem_manager_);
mem_scheduler->RecordMemUsage();
InitGraphInputTensors(mem_scheduler, graph);
}
const auto &kernels = graph->execution_order();
std::vector<DynamicKernelPtr> dynamic_kernel_list;
auto iter = graph_dynamic_kernel_map_.find(graph.graph_id());
auto iter = graph_dynamic_kernel_map_.find(graph->graph_id());
if (iter != graph_dynamic_kernel_map_.end()) {
dynamic_kernel_list = iter->second;
}
@@ -1259,7 +1479,7 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph) {
}
std::vector<std::vector<std::function<void()>>> kernel_pre_run_events;
std::vector<std::vector<std::function<void()>>> kernel_post_run_events;
auto events_iter = graph_kernel_events_map_.find(graph.graph_id());
auto events_iter = graph_kernel_events_map_.find(graph->graph_id());
if (events_iter != graph_kernel_events_map_.end()) {
kernel_pre_run_events = events_iter->second.first;
kernel_post_run_events = events_iter->second.second;
@@ -1271,7 +1491,6 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph) {
dynamic_kernel_list[i]->InferShape();
dynamic_kernel_list[i]->UpdateArgs();
dynamic_kernel_list[i]->Execute();

if (!SyncStream()) {
MS_LOG(ERROR) << "SyncStream failed";
return false;
@@ -1292,7 +1511,7 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph) {
}
continue;
}
auto ret = LaunchKernel(kernel);
auto ret = LaunchKernel(graph, kernel, mem_scheduler, mock);
if (!ret) {
MS_LOG(ERROR) << "Launch kernel failed.";
return false;
@@ -1302,16 +1521,42 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph) {
}
LaunchKernelEvent(kernel_post_run_events, i);
}
if (mem_scheduler != nullptr) {
mem_scheduler->OptMemUsage();
}
return true;
}

void KernelRuntime::UseMemSchedulerIfNeeded(const session::KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(graph);
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
auto enable_mem_scheduler = context_ptr->get_param<bool>(MS_CTX_ENABLE_MEM_SCHEDULER);
if (enable_mem_scheduler) {
auto mem_scheduler = mem_scheduler_manager_.GetOrCreateMemScheduler(graph->graph_id());
if (mem_scheduler->need_record_event()) {
(void)LaunchKernelMod(graph, true);
}
float mem_used_factor = kMaxMemReuseFactor;
while (!mem_scheduler->optimized() && mem_used_factor >= kMinMemReuseFactor) {
mem_scheduler->SetMemUsedFactor(mem_used_factor);
bool ret = LaunchKernelMod(graph, true);
if (ret) {
mem_scheduler->SetOptimized(true);
} else {
mem_used_factor -= kRetryFactor;
}
}
}
}

bool KernelRuntime::LaunchKernels(const session::KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(graph);
if (!LaunchKernelMod(*graph)) {
UseMemSchedulerIfNeeded(graph);
if (!LaunchKernelMod(graph)) {
MS_LOG(ERROR) << "LaunchKernelMod failed!";
return false;
}

auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
@@ -33,6 +33,7 @@
#include "backend/kernel_compiler/kernel.h"
#include "utils/ms_context.h"
#include "runtime/device/memory_manager.h"
#include "runtime/device/memory_scheduler.h"
#include "runtime/device/executor/dynamic_kernel.h"
#include "ir/device_event.h"
@@ -69,7 +70,6 @@ class KernelRuntime {
virtual bool GenDynamicKernel(const session::KernelGraph *graph) = 0;
virtual bool RunDynamicKernelAsync(const session::KernelGraph *graph) = 0;
bool LaunchKernels(const session::KernelGraph *graph);
virtual bool LaunchKernel(const AnfNodePtr &kernel);
virtual void AssignStaticMemoryInput(const session::KernelGraph *graph);
virtual void AssignStaticMemoryValueNode(session::KernelGraph *graph);
virtual void ClearGraphRuntimeResource(uint32_t graph_id);
@@ -141,11 +141,24 @@ class KernelRuntime {
virtual void KernelLaunchProfiling(const std::string &kernel_name) {}

private:
void UseMemSchedulerIfNeeded(const session::KernelGraph *graph);
bool LaunchKernel(const session::KernelGraph *graph, const AnfNodePtr &kernel,
const std::shared_ptr<MemScheduler> &mem_scheduler, bool mock = false);
void ResetNodeAddress(session::KernelGraph *graph);
void AssignKernelAddress(const std::shared_ptr<MemScheduler> &mem_scheduler, const AnfNodePtr &kernel,
AddressPtrList *kernel_inputs, AddressPtrList *kernel_workspaces,
AddressPtrList *kernel_outputs);
static void GetOrMallocAddress(const std::shared_ptr<MemScheduler> &mem_scheduler,
const DeviceAddress *device_address, const kernel::AddressPtr &kernel_addr);
void InitGraphInputTensors(const std::shared_ptr<MemScheduler> &mem_scheduler, const session::KernelGraph *graph);
void SyncNodeOutputTensors(const std::shared_ptr<MemScheduler> &mem_scheduler, const session::KernelGraph *graph,
const AnfNodePtr &kernel, bool mock);
void AssignStaticMemoryOutput(const session::KernelGraph *graph);
bool LaunchKernelMod(const session::KernelGraph &graph);
bool LaunchKernelMod(const session::KernelGraph *graph, bool mock = false);
void LaunchKernelEvent(const std::vector<std::vector<std::function<void()>>> &run_events, size_t index) const;
void DebugStreamSync(const CNodePtr &kernel);
static void GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList *kernel_inputs);
static void GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList *kernel_inputs,
const std::shared_ptr<MemScheduler> &mem_schedule = nullptr);
void RunOpAssignInputMemory(const std::vector<tensor::TensorPtr> &input_tensors, const session::KernelGraph *graph);
void RunOpAssignOutputMemory(const AnfNodePtr &kernel,
const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node = {});
@@ -179,6 +192,7 @@ class KernelRuntime {
std::map<uint32_t,
std::pair<std::vector<std::vector<std::function<void()>>>, std::vector<std::vector<std::function<void()>>>>>
graph_kernel_events_map_;
MemSchedulerManager mem_scheduler_manager_;
};
using KernelRuntimePtr = std::shared_ptr<KernelRuntime>;
} // namespace device
@@ -19,8 +19,11 @@
#include <memory>
#include <utility>
#include <vector>
#include <map>
#include <queue>
#include "backend/optimizer/mem_reuse/mem_reuse.h"
#include "backend/optimizer/somas/somas.h"
#include "runtime/device/memory_scheduler.h"
namespace mindspore {
namespace device {
enum MemType { kStaticMem, kDynamicMem, kSomasReuseDynamicMem };
@@ -28,7 +31,7 @@ const int kGetAllOuts = -1;
const uint64_t kMemAlignSize = 512;
using SomasPtr = mindspore::somas::SomasPtr;

class MemoryManager {
class MemoryManager : public MemHandler {
public:
MemoryManager() = default;
virtual ~MemoryManager() = default;
@@ -60,6 +63,45 @@ class MemoryManager {
static size_t GetCommonAlignSize(size_t input_size);
static size_t GetCommunicationAlignSize(size_t input_size);

// swap manager interface
void *MallocDevice(size_t mem_size) override { return MallocMemFromMemPool(mem_size); }
void FreeDevice(void *ptr) override {
MS_EXCEPTION_IF_NULL(ptr);
FreeMemFromMemPool(ptr);
}
void *MallocHost(size_t mem_size) override {
auto &mem_que = cached_host_mem_[mem_size];
if (!mem_que.empty()) {
auto ret = mem_que.front();
mem_que.pop();
return ret;
}
auto block = std::make_shared<std::vector<uint8_t>>();
try {
block->resize(mem_size, 0);
auto ptr = block->data();
host_mem_block_map_[ptr] = block;
return ptr;
} catch (const std::exception &e) {
MS_LOG(EXCEPTION) << "Malloc memory failed: size " << mem_size;
}
}
void FreeHost(void *ptr) override {
MS_EXCEPTION_IF_NULL(ptr);
auto iter = host_mem_block_map_.find(ptr);
if (iter == host_mem_block_map_.end()) {
MS_LOG(ERROR) << "Free ptr not be created from manager!";
}
auto mem_size = iter->second->size();
cached_host_mem_[mem_size].emplace(iter->first);
}
void SwapIn(void *host_ptr, void *device_ptr, size_t mem_size, void *stream) override {}
void SwapOut(void *device_ptr, void *host_ptr, size_t mem_size, void *stream) override {}
size_t GetAvailableMemSize() override {
MS_LOG(ERROR) << "Return default 0 mem size!";
return 0;
}

protected:
virtual uint8_t *MallocStaticMem(size_t size, bool communication_mem, uint32_t graph_id = kInvalidGraphId) = 0;
virtual uint8_t *MallocDynamicMem(size_t size, bool communication_mem);
@@ -70,6 +112,8 @@ class MemoryManager {
size_t total_static_size_ = 0;
size_t total_dynamic_size_ = 0;
SomasPtr somas_reuse_util_ptr_{nullptr};
std::map<size_t, std::queue<void *>> cached_host_mem_;
std::map<void *, std::shared_ptr<std::vector<uint8_t>>> host_mem_block_map_;
};
} // namespace device
} // namespace mindspore
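With the additions above, MemoryManager itself satisfies the new MemHandler interface consumed by MemScheduler (kernel_runtime.cc passes mem_manager_ to SetMemHandler): device buffers come from the existing memory pool, while MallocHost/FreeHost keep a per-size queue of previously allocated host blocks so repeated swap-outs of the same size reuse host memory instead of reallocating it.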
@@ -0,0 +1,377 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "runtime/device/memory_scheduler.h"
#include <map>
#include <set>
#include <memory>
#include <utility>
#include <algorithm>
#include "utils/log_adapter.h"
namespace mindspore {
namespace device {
void MemScheduler::Clear() {
if (mem_handler_ == nullptr) {
return;
}
for (auto &item : high_priority_device_ptr_) {
mem_handler_->FreeDevice(item.second);
}
high_priority_device_ptr_.clear();
}

bool MemScheduler::IsHighPriorityMem(const void *key) {
auto iter = mem_priority_.find(key);
if (iter != mem_priority_.end()) {
return iter->second == kMemPriorityHigh;
}
return false;
}

void MemScheduler::SetMemPriority(const void *key, MemPriority priority) { mem_priority_[key] = priority; }

void MemScheduler::Record(const void *key, const EventType &event_type, size_t mem_size) {
if (key == nullptr) {
return;
}
auto event = std::make_shared<Event>(event_type, compute_index_);
event->mem_size = mem_size;
event->key = key;
mem_events_[key].emplace_back(event);
}

void MemScheduler::Init(const void *key, void *host_ptr, size_t mem_size, MemPriority priority) {
if (need_record_event_) {
mem_priority_[key] = priority;
Record(key, kInit, mem_size);
} else {
init_host_ptr_[key] = host_ptr;
}
}

void *MemScheduler::GetOrMalloc(const void *key, size_t mem_size, MemPriority priority) {
if (need_record_event_) {
if (mem_priority_.find(key) == mem_priority_.end()) {
mem_priority_[key] = priority;
Record(key, kMalloc, mem_size);
} else {
Record(key, kGet, mem_size);
}
return nullptr;
}
auto iter = mem_result_.find(key);
if (iter != mem_result_.end()) {
auto ptr = iter->second;
MS_EXCEPTION_IF_NULL(ptr);
return ptr;
} else {
MS_LOG_EXCEPTION << "Mem extender get nullptr result!";
}
}

bool MemScheduler::PreCompute(void *stream) {
if (need_record_event_) {
return true;
}
MS_EXCEPTION_IF_NULL(mem_handler_);
if (pre_compute_events_.size() <= compute_index_) {
MS_LOG_EXCEPTION << "Index out of pre event range, index:" << compute_index_
<< ", event size:" << pre_compute_events_.size();
}
auto &events = pre_compute_events_[compute_index_];
for (auto &event : events) {
MS_EXCEPTION_IF_NULL(event);
MS_LOG(DEBUG) << "Pre compute " << compute_index_ << ": " << event->key << " v " << event->type;
if (event->type == kInit) {
auto host_ptr = init_host_ptr_[event->key];
MS_EXCEPTION_IF_NULL(host_ptr);
auto priority = mem_priority_[event->key];
auto iter = high_priority_device_ptr_.find(event->key);
if (priority != kMemPriorityLow && iter != high_priority_device_ptr_.end()) {
MS_EXCEPTION_IF_NULL(iter->second);
mem_result_[event->key] = iter->second;
if (priority == kMemPriorityMedium) {
mem_handler_->SwapIn(host_ptr, iter->second, event->mem_size, stream);
}
continue;
}
auto device_ptr = mem_handler_->MallocDevice(event->mem_size);
if (device_ptr == nullptr) {
return false;
}
if (priority != kMemPriorityLow) {
high_priority_device_ptr_[event->key] = device_ptr;
}
mem_handler_->SwapIn(host_ptr, device_ptr, event->mem_size, stream);
mem_result_[event->key] = device_ptr;
} else if (event->type == kMalloc) {
auto device_ptr = mem_handler_->MallocDevice(event->mem_size);
if (device_ptr == nullptr) {
return false;
}
mem_result_[event->key] = device_ptr;
} else if (event->type == kSwapIn) {
bool from_init = true;
auto host_ptr = init_host_ptr_[event->key];
if (host_ptr == nullptr) {
host_ptr = swap_host_ptr_[event->key];
from_init = false;
}
auto device_ptr = mem_handler_->MallocDevice(event->mem_size);
if (device_ptr == nullptr) {
return false;
}
MS_EXCEPTION_IF_NULL(host_ptr);
mem_handler_->SwapIn(host_ptr, device_ptr, event->mem_size, stream);
mem_result_[event->key] = device_ptr;
if (!from_init) {
mem_handler_->FreeHost(host_ptr);
swap_host_ptr_.erase(event->key);
}
}
}
return true;
}

bool MemScheduler::PostCompute(void *stream) {
if (need_record_event_) {
++compute_index_;
return true;
}
MS_EXCEPTION_IF_NULL(mem_handler_);
if (post_compute_events_.size() <= compute_index_) {
MS_LOG_EXCEPTION << "Index out of post event range, index:" << compute_index_
<< ", event size:" << post_compute_events_.size();
}
auto &events = post_compute_events_[compute_index_];
for (auto &event : events) {
MS_EXCEPTION_IF_NULL(event);
MS_LOG(DEBUG) << "Post compute " << compute_index_ << ": " << event->key << " v " << event->type;
if (event->type == kFree) {
auto ptr = mem_result_[event->key];
if (ptr == nullptr) {
return false;
}
mem_handler_->FreeDevice(ptr);
(void)mem_result_.erase(event->key);
} else if (event->type == kSwapOut) {
auto device_ptr = mem_result_[event->key];
if (device_ptr == nullptr) {
return false;
}
auto host_ptr = init_host_ptr_[event->key];
if (host_ptr == nullptr) {
host_ptr = mem_handler_->MallocHost(event->mem_size);
swap_host_ptr_[event->key] = host_ptr;
}
MS_EXCEPTION_IF_NULL(host_ptr);
mem_handler_->SwapOut(device_ptr, host_ptr, event->mem_size, stream);
mem_handler_->FreeDevice(device_ptr);
mem_result_.erase(device_ptr);
}
}
++compute_index_;
return true;
}

void MemScheduler::OptMemUsage() {
need_record_event_ = false;
if (optimized_) {
return;
}
CountMemUsage();
CheckMemSize();
if (need_swap_) {
GenEventSpan();
GenNoSwapEventSet();
}
GenEvents();
}

void MemScheduler::CountMemUsage() {
if (!min_mem_used_.empty()) {
return;
}
min_mem_used_.resize(compute_index_, 0);
std::vector<size_t> total_mem_used(compute_index_, 0);
for (auto &item : mem_events_) {
auto &mem_events = item.second;
if (mem_events.empty()) {
continue;
}
auto first_event = mem_events[0];
MS_EXCEPTION_IF_NULL(first_event);
size_t i = 0;
if (first_event->type == kInit && mem_events.size() > 1) {
first_event = mem_events[1];
i = 1;
}
auto last_event = mem_events[mem_events.size() - 1];
for (size_t start_index = first_event->index; start_index <= last_event->index; ++start_index) {
if (start_index < compute_index_) {
total_mem_used[start_index] += first_event->mem_size;
} else {
MS_LOG(ERROR) << "Error mem event index " << start_index;
}
}
for (; i < mem_events.size(); ++i) {
auto &event = mem_events[i];
MS_EXCEPTION_IF_NULL(event);
if (event->index < compute_index_) {
min_mem_used_[event->index] += first_event->mem_size;
} else {
MS_LOG(ERROR) << "Error mem event index " << event->index;
}
}
}
min_mem_needed_ = *(std::max_element(min_mem_used_.begin(), min_mem_used_.end()));
mem_used_without_swap_ = *(std::max_element(total_mem_used.begin(), total_mem_used.end()));
}

void MemScheduler::CheckMemSize() {
MS_EXCEPTION_IF_NULL(mem_handler_);
auto available_mem_size = mem_handler_->GetAvailableMemSize();
if (available_mem_size < min_mem_needed_) {
MS_LOG(EXCEPTION) << "Out of memory, as available mem size is " << available_mem_size
<< " while graph needs at least " << min_mem_needed_;
}
if (mem_used_without_swap_ > available_mem_size) {
need_swap_ = true;
}
MS_LOG(INFO) << "Available mem size: " << available_mem_size << ", graph needs mem size:" << mem_used_without_swap_
<< "without swap, and needs at least " << min_mem_needed_ << " with swap.";
}

void MemScheduler::GenEventSpan() {
if (!event_span_.empty()) {
return;
}
for (auto &item : mem_events_) {
auto &mem_events = item.second;
if (mem_events.empty()) {
continue;
}
auto first_event = mem_events[0];
MS_EXCEPTION_IF_NULL(first_event);
size_t i = 0;
if (first_event->type == kInit && mem_events.size() > 1) {
first_event = mem_events[1];
i = 1;
}
size_t last_index = first_event->index;
for (; i < mem_events.size(); ++i) {
auto &event = mem_events[i];
MS_EXCEPTION_IF_NULL(event);
auto span = event->index - last_index;
if (span > 1) {
(void)event_span_.insert(std::pair<size_t, std::shared_ptr<Event>>(span, event));
}
last_index = event->index;
}
}
}

void MemScheduler::GenNoSwapEventSet() {
MS_EXCEPTION_IF_NULL(mem_handler_);
auto available_mem_size = mem_handler_->GetAvailableMemSize();
auto threshold = available_mem_size * mem_used_factor_;
no_swap_events_.clear();
std::vector<size_t> cur_mem_used(min_mem_used_.begin(), min_mem_used_.end());
for (auto iter = event_span_.begin(); iter != event_span_.end(); ++iter) {
auto span = iter->first;
auto &event = iter->second;
auto start_index = event->index - span + 1;
bool revert = false;
for (size_t i = start_index; i < event->index; ++i) {
cur_mem_used[i] += event->mem_size;
if (cur_mem_used[i] > threshold) {
revert = true;
}
}
if (revert) {
for (size_t i = start_index; i < event->index; ++i) {
cur_mem_used[i] -= event->mem_size;
}
} else {
no_swap_events_.emplace(event);
}
}
}

void MemScheduler::GenEvents() {
pre_compute_events_.resize(compute_index_);
post_compute_events_.resize(compute_index_);
for (auto &item : mem_events_) {
auto &mem_events = item.second;
if (mem_events.empty()) {
continue;
}
auto first_event = mem_events[0];
MS_EXCEPTION_IF_NULL(first_event);
if (first_event->type == kInit) {
if (mem_events.size() > 1) {
auto &second_event = mem_events[1];
MS_EXCEPTION_IF_NULL(second_event);
first_event->index = second_event->index;
} else {
continue;
}
}
if ((first_event->type == kInit || first_event->type == kMalloc) &&
first_event->index < pre_compute_events_.size()) {
pre_compute_events_[first_event->index].emplace_back(first_event);
} else {
MS_LOG_EXCEPTION << "First event should be init or malloc!";
}
MemPriority priority = kMemPriorityLow;
auto iter = mem_priority_.find(first_event->key);
if (iter != mem_priority_.end()) {
priority = iter->second;
}
size_t pre_index = first_event->index;
for (size_t i = 1; i < mem_events.size(); ++i) {
auto &event = mem_events[i];
MS_EXCEPTION_IF_NULL(event);
if (need_swap_ && event->index - pre_index > 1 && priority == kMemPriorityLow &&
no_swap_events_.find(event) == no_swap_events_.end()) {
auto swap_out_event = std::make_shared<Event>(kSwapOut, pre_index);
swap_out_event->key = item.first;
swap_out_event->mem_size = first_event->mem_size;
post_compute_events_[pre_index].emplace_back(swap_out_event);
auto swap_in_event = std::make_shared<Event>(kSwapIn, event->index);
swap_in_event->key = item.first;
swap_in_event->mem_size = first_event->mem_size;
pre_compute_events_[event->index].emplace_back(swap_in_event);
}
if (event->index < pre_compute_events_.size()) {
pre_compute_events_[event->index].emplace_back(event);
}
pre_index = event->index;
}
if (priority != kMemPriorityLow) {
continue;
}
auto &last_event = mem_events[mem_events.size() - 1];
MS_EXCEPTION_IF_NULL(last_event);
auto free_event = std::make_shared<Event>(kFree, last_event->index);
free_event->key = item.first;
if (last_event->index < post_compute_events_.size()) {
post_compute_events_[last_event->index].emplace_back(free_event);
}
}
}
} // namespace device
} // namespace mindspore
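Editorial sketch of how GenEvents above turns a recorded usage pattern into swap events (the buffer B and the step numbers are hypothetical):

// Assume swapping is needed and a low-priority buffer B was recorded as kMalloc at step 2
// and kGet at step 10, and the span-8 gap was not admitted into no_swap_events_ by
// GenNoSwapEventSet. GenEvents then schedules:
//   pre_compute_events_[2]   : kMalloc B           (allocate before step 2)
//   post_compute_events_[2]  : kSwapOut B          (copy to host, release device memory)
//   pre_compute_events_[10]  : kSwapIn B, kGet B   (restore before step 10)
//   post_compute_events_[10] : kFree B             (last use of a low-priority buffer)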
@@ -0,0 +1,143 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_MEMORY_SCHEDULER_H_
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_MEMORY_SCHEDULER_H_
#include <vector>
#include <map>
#include <set>
#include <memory>
#include <utility>

namespace mindspore {
namespace device {
class MemHandler {
public:
virtual size_t GetAvailableMemSize() = 0;
virtual void *MallocDevice(size_t mem_size) = 0;
virtual void FreeDevice(void *ptr) = 0;
virtual void *MallocHost(size_t mem_size) = 0;
virtual void FreeHost(void *ptr) = 0;
virtual void SwapIn(void *host_ptr, void *device_ptr, size_t mem_size, void *stream) = 0;
virtual void SwapOut(void *device_ptr, void *host_ptr, size_t mem_size, void *stream) = 0;
};

enum MemPriority { kMemPriorityLow, kMemPriorityMedium, kMemPriorityHigh };

class MemScheduler {
enum EventType { kInit, kMalloc, kGet, kFree, kSwapIn, kSwapOut };

struct Event {
Event(const EventType &in_type, size_t in_index) {
type = in_type;
index = in_index;
}

EventType type;
size_t index{0};
size_t mem_size{0};
const void *key{nullptr};
};

public:
MemScheduler() = default;
~MemScheduler() = default;

bool need_record_event() const { return need_record_event_; }

bool optimized() const { return optimized_; }

void SetOptimized(bool flag) { optimized_ = flag; }

void SetMemHandler(const std::shared_ptr<MemHandler> &handler) { mem_handler_ = handler; }

void Init(const void *key, void *host_ptr, size_t mem_size, MemPriority priority = kMemPriorityLow);

void *GetOrMalloc(const void *key, size_t mem_size, MemPriority priority = kMemPriorityLow);

void RecordMemUsage() { compute_index_ = 0; }

bool PreCompute(void *stream);

bool PostCompute(void *stream);

void OptMemUsage();

void Clear();

bool IsHighPriorityMem(const void *key);

void SetMemPriority(const void *key, MemPriority priority);

void SetMemUsedFactor(float factor) { mem_used_factor_ = factor; }

void SetNeedSwap(bool flag) { need_swap_ = flag; }

private:
void Record(const void *key, const EventType &event_type, size_t mem_size = 0);
void GenEvents();
void CheckMemSize();
void CountMemUsage();
void GenEventSpan();
void GenNoSwapEventSet();
std::map<const void *, MemPriority> mem_priority_;
std::map<const void *, std::vector<std::shared_ptr<Event>>> mem_events_;
std::vector<std::vector<std::shared_ptr<Event>>> pre_compute_events_;
std::vector<std::vector<std::shared_ptr<Event>>> post_compute_events_;
std::map<const void *, void *> mem_result_;
std::map<const void *, void *> init_host_ptr_;
std::map<const void *, void *> swap_host_ptr_;
std::map<const void *, void *> high_priority_device_ptr_;
size_t compute_index_{0};
bool need_record_event_{true};
bool optimized_{false};
std::shared_ptr<MemHandler> mem_handler_{nullptr};
bool need_swap_{false};
std::multimap<size_t, std::shared_ptr<Event>> event_span_;
std::set<std::shared_ptr<Event>> no_swap_events_;
std::vector<size_t> min_mem_used_;
size_t mem_used_without_swap_{0};
size_t min_mem_needed_{0};
float mem_used_factor_{0.9};
};

class MemSchedulerManager {
public:
MemSchedulerManager() = default;
~MemSchedulerManager() = default;
std::shared_ptr<MemScheduler> GetOrCreateMemScheduler(uint64_t uid) {
auto scheduler = GetMemScheduler(uid);
if (scheduler == nullptr) {
scheduler = std::make_shared<MemScheduler>();
graph_mem_scheduler_map_[uid] = scheduler;
}
return scheduler;
}

std::shared_ptr<MemScheduler> GetMemScheduler(uint64_t uid) {
auto iter = graph_mem_scheduler_map_.find(uid);
if (iter != graph_mem_scheduler_map_.end()) {
return iter->second;
}
return nullptr;
}

private:
std::map<uint64_t, std::shared_ptr<MemScheduler>> graph_mem_scheduler_map_;
};
} // namespace device
} // namespace mindspore
#endif  // MINDSPORE_CCSRC_RUNTIME_DEVICE_MEMORY_SCHEDULER_H_
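A minimal usage sketch of the API declared above, not part of this change: the handler argument and the key variables are hypothetical stand-ins (in the runtime the handler is the MemoryManager and the keys are DeviceAddress pointers), and the handler must report the real available device memory for OptMemUsage to succeed.

// Illustrative only: record one pass, optimize, then replay with real allocation/swap.
#include <memory>
#include <vector>
#include "runtime/device/memory_scheduler.h"

void MemSchedulerSketch(const std::shared_ptr<mindspore::device::MemHandler> &handler) {
  using namespace mindspore::device;
  auto scheduler = std::make_shared<MemScheduler>();
  scheduler->SetMemHandler(handler);
  std::vector<uint8_t> weight_host(1024);   // host copy handed to Init on every pass
  int weight_key = 0, activation_key = 0;   // any stable addresses can serve as keys
  const size_t steps = 2;

  // Pass 1: record. GetOrMalloc only logs kMalloc/kGet events and returns nullptr.
  scheduler->RecordMemUsage();
  scheduler->Init(&weight_key, weight_host.data(), weight_host.size(), kMemPriorityHigh);
  for (size_t step = 0; step < steps; ++step) {
    scheduler->PreCompute(nullptr);   // no-op while recording
    (void)scheduler->GetOrMalloc(&weight_key, weight_host.size());
    (void)scheduler->GetOrMalloc(&activation_key, 4096);
    scheduler->PostCompute(nullptr);  // advances the step counter
  }
  scheduler->OptMemUsage();  // builds per-step pre/post event lists, deciding what to swap

  // Pass 2: replay. PreCompute mallocs/swaps in, GetOrMalloc returns device pointers,
  // and PostCompute frees or swaps out low-priority buffers.
  scheduler->RecordMemUsage();  // resets the step counter; the recorded events are kept
  scheduler->Init(&weight_key, weight_host.data(), weight_host.size(), kMemPriorityHigh);
  for (size_t step = 0; step < steps; ++step) {
    scheduler->PreCompute(nullptr);
    void *weight_dev = scheduler->GetOrMalloc(&weight_key, weight_host.size());
    void *activation_dev = scheduler->GetOrMalloc(&activation_key, 4096);
    // launch the kernel for this step with weight_dev / activation_dev here
    (void)weight_dev;
    (void)activation_dev;
    scheduler->PostCompute(nullptr);
  }
}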
@@ -481,6 +481,7 @@ constexpr auto kValueTargetOther = "target_other";

// env key
constexpr auto kGraphOpRun = "GRAPH_OP_RUN";
constexpr auto kEnableMemScheduler = "ENABLE_MEM_SCHEDULER";

// some size
const size_t kShape4dDims = 4;
@@ -602,8 +602,13 @@ BackendPtr CreateBackend() {
backend->set_is_multi_graph_sink(false);
context_ptr->set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, false);
} else {
auto single_op = std::getenv(kGraphOpRun);
if (single_op && std::string(single_op) == "1") {
auto single_op = common::GetEnv(kGraphOpRun);
if (single_op == "1") {
context_ptr->set_param<bool>(MS_CTX_ENABLE_TASK_SINK, false);
}
auto enable_mem_scheduler = common::GetEnv(kEnableMemScheduler);
if (enable_mem_scheduler == "1") {
context_ptr->set_param<bool>(MS_CTX_ENABLE_MEM_SCHEDULER, true);
context_ptr->set_param<bool>(MS_CTX_ENABLE_TASK_SINK, false);
}
}
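Usage note: together with the InitHccl change above, the scheduler stays off by default and is enabled by exporting ENABLE_MEM_SCHEDULER=1, which sets MS_CTX_ENABLE_MEM_SCHEDULER and also forces task sink off, so the graph is launched kernel by kernel where the scheduler can run PreCompute/PostCompute around every launch.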
@@ -81,6 +81,7 @@ MsContext::MsContext(const std::string &policy, const std::string &target) {
set_param<std::string>(MS_CTX_VARIABLE_MEMORY_MAX_SIZE, "0");
set_param<bool>(MS_CTX_ENABLE_LOOP_SINK, target == kAscendDevice || target == kDavinciDevice);
set_param<bool>(MS_CTX_ENABLE_PROFILING, false);
set_param<bool>(MS_CTX_ENABLE_MEM_SCHEDULER, false);
set_param<std::string>(MS_CTX_PROFILING_OPTIONS, "training_trace");
set_param<bool>(MS_CTX_CHECK_BPROP_FLAG, false);
set_param<float>(MS_CTX_MAX_DEVICE_MEMORY, kDefaultMaxDeviceMemory);
@@ -72,6 +72,7 @@ enum MsCtxParam : unsigned {
MS_CTX_ENABLE_GRAPH_KERNEL,
MS_CTX_ENABLE_HCCL,
MS_CTX_ENABLE_LOOP_SINK,
MS_CTX_ENABLE_MEM_SCHEDULER,
MS_CTX_ENABLE_PYNATIVE_HOOK,
MS_CTX_ENABLE_PYNATIVE_INFER,
MS_CTX_ENABLE_REDUCE_PRECISION,