!22092 [executor] Add mem scheduler

Merge pull request !22092 from kisnwang/add-mem-extend-cache
This commit is contained in:
i-robot 2021-09-16 07:40:00 +00:00 committed by Gitee
commit 6f09891501
22 changed files with 992 additions and 80 deletions

View File

@ -43,11 +43,8 @@ class AscendKernelMod : public KernelMod {
return false;
#endif
}
void SetStream(void *stream) { stream_ = stream; }
void *GetStream() { return stream_; }
protected:
void *stream_{nullptr};
uint32_t block_dim_{1};
uint32_t stream_id_{0};
};

View File

@ -195,11 +195,14 @@ class KernelMod {
const std::vector<AddressPtr> &GetInputsAddr() { return inputs_addr_; }
const std::vector<AddressPtr> &GetWorkSpacesAddr() { return workspaces_addr_; }
const std::vector<AddressPtr> &GetOutputsAddr() { return outputs_addr_; }
void SetStream(void *stream) { stream_ = stream; }
void *GetStream() const { return stream_; }
protected:
std::string unique_name_;
std::string fullname_;
bool is_monad_{false};
void *stream_{nullptr};
private:
std::vector<AddressPtr> inputs_addr_;

View File

@ -45,18 +45,6 @@ using DeviceAddress = device::DeviceAddress;
using DeviceAddressPtr = device::DeviceAddressPtr;
using Address = kernel::Address;
using AddressPtr = kernel::AddressPtr;
using KernelWithIndex = std::pair<AnfNodePtr, size_t>;
struct KernelWithIndexCmp {
bool operator()(const KernelWithIndex &key1, const KernelWithIndex &key2) const {
if (key1.first != key2.first) {
return key1.first < key2.first;
}
if (key1.second != key2.second) {
return key1.second < key2.second;
}
return false;
}
};
class OpRuntimeInfo {
public:

View File

@ -413,13 +413,18 @@ void AscendSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_gra
MS_LOG(EXCEPTION) << "Tensor input:" << inputs.size() << " is not equal graph inputs:" << input_nodes.size()
<< ", input_ctrl_size:" << input_ctrl_size;
}
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
auto enable_mem_scheduler = ms_context->get_param<bool>(MS_CTX_ENABLE_MEM_SCHEDULER);
if (enable_mem_scheduler) {
kernel_graph->SetInputTensors(inputs);
return;
}
for (auto item : tensor_device_addr_map_) {
auto output_tensor = item.first;
output_tensor->set_device_address(item.second);
}
SyncStream();
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
for (size_t i = 0; i < inputs.size(); ++i) {
auto tensor = inputs[i];
MS_EXCEPTION_IF_NULL(tensor);
@ -655,7 +660,10 @@ void AscendSession::BuildGraphImpl(GraphId graph_id) {
} else {
// alloc memory, including static memory and dynamic memory
MemoryAlloc(graph.get());
AnfAlgo::CacheAddrForGraph(graph);
auto enable_mem_scheduler = ms_context->get_param<bool>(MS_CTX_ENABLE_MEM_SCHEDULER);
if (!enable_mem_scheduler) {
AnfAlgo::CacheAddrForGraph(graph);
}
// generate and load task info to device if it is sink mode
Load(graph);
}
@ -690,10 +698,13 @@ void AscendSession::CompileChildGraph(const KernelGraphPtr &child_graph) {
// optimize graph
HardwareOptimize(child_graph);
// assign static memory of parameters
auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
MS_EXCEPTION_IF_NULL(runtime_instance);
runtime_instance->AssignStaticMemoryInput(child_graph.get());
runtime_instance->AssignStaticMemoryValueNode(child_graph.get());
auto enable_mem_scheduler = context_ptr->get_param<bool>(MS_CTX_ENABLE_MEM_SCHEDULER);
if (!enable_mem_scheduler) {
auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
MS_EXCEPTION_IF_NULL(runtime_instance);
runtime_instance->AssignStaticMemoryInput(child_graph.get());
runtime_instance->AssignStaticMemoryValueNode(child_graph.get());
}
}
bool AscendSession::IsSupportSummary() { return !device::KernelAdjust::NeedInsertSwitch(); }
@ -1954,6 +1965,12 @@ void AscendSession::ExecuteAllTaskInQueue() {
void AscendSession::UpdateOutputTensors(const VectorRef *outputs,
const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node,
std::map<DeviceAddressPtr, DeviceAddressPtr> *) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
auto enable_mem_scheduler = context_ptr->get_param<bool>(MS_CTX_ENABLE_MEM_SCHEDULER);
if (enable_mem_scheduler) {
return;
}
MS_EXCEPTION_IF_NULL(outputs);
tensor_device_addr_map_.clear();
for (const auto &item : *outputs) {

View File

@ -36,6 +36,21 @@
namespace mindspore {
namespace session {
using AnfWithOutIndex = std::pair<AnfNodePtr, size_t>;
using KernelWithIndex = std::pair<AnfNodePtr, size_t>;
struct KernelWithIndexCmp {
bool operator()(const KernelWithIndex &key1, const KernelWithIndex &key2) const {
if (key1.first != key2.first) {
return key1.first < key2.first;
}
if (key1.second != key2.second) {
return key1.second < key2.second;
}
return false;
}
};
using KernelMapTensor = std::map<session::KernelWithIndex, BaseRef, session::KernelWithIndexCmp>;
class KernelGraph : public FuncGraph {
public:
KernelGraph() : graph_id_(0), start_label_(nullptr), end_goto_(nullptr), current_epoch_(0), is_dynamic_shape_(false) {
@ -260,8 +275,21 @@ class KernelGraph : public FuncGraph {
void SetOptimizerFlag();
void SetInputNodes();
const std::vector<AnfNodePtr> &input_nodes() const { return input_nodes_; }
void SetInputTensors(const std::vector<tensor::TensorPtr> &input_tensors) { input_tensors_ = input_tensors; }
const std::vector<tensor::TensorPtr> &input_tensors() const { return input_tensors_; }
void SetOutputNodeToTensor(const KernelMapTensor &node_to_tensor) { output_node_to_tensor_ = node_to_tensor; }
tensor::TensorPtr GetNodeOutputTensor(const session::KernelWithIndex &output_index) const {
auto iter = output_node_to_tensor_.find(output_index);
if (iter != output_node_to_tensor_.end()) {
return utils::cast<tensor::TensorPtr>(iter->second);
}
return nullptr;
}
bool has_optimizer() const { return has_optimizer_; }
bool IsUpdatedParameter(const ParameterPtr &param) {
bool IsUpdatedParameter(const ParameterPtr &param) const {
if (updated_parameters_.find(param) != updated_parameters_.end()) {
return true;
}
@ -428,6 +456,8 @@ class KernelGraph : public FuncGraph {
std::map<AnfNodePtr, AnfNodePtr> edge_to_;
std::stack<AnfNodePtr> loop_nodes_;
std::vector<AnfNodePtr> input_nodes_;
std::vector<tensor::TensorPtr> input_tensors_;
KernelMapTensor output_node_to_tensor_;
std::unordered_map<uint32_t, std::weak_ptr<session::KernelGraph>> pre_graphs_;
std::unordered_map<uint32_t, std::weak_ptr<session::KernelGraph>> post_graphs_;
// The send/recv pairs inserted for allreduce, the key is allreduce kernel, the first of pair is send node, the second

View File

@ -1689,11 +1689,23 @@ void SessionBasic::CreateOutputTensors(const GraphId &graph_id, const std::vecto
MS_LOG(INFO) << "Create node output[" << item->DebugString() << "]";
outputs->emplace_back(CreateNodeOutputTensors(item, kernel_graph, input_tensors, tensor_to_node, &node_to_tensor));
}
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
auto enable_mem_scheduler = ms_context->get_param<bool>(MS_CTX_ENABLE_MEM_SCHEDULER);
if (enable_mem_scheduler) {
kernel_graph->SetOutputNodeToTensor(node_to_tensor);
}
}
void SessionBasic::UpdateOutputTensors(const VectorRef *outputs,
const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node,
std::map<DeviceAddressPtr, DeviceAddressPtr> *) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
auto enable_mem_scheduler = context_ptr->get_param<bool>(MS_CTX_ENABLE_MEM_SCHEDULER);
if (enable_mem_scheduler) {
return;
}
MS_EXCEPTION_IF_NULL(outputs);
for (const auto &item : *outputs) {
if (utils::isa<VectorRefPtr>(item)) {

View File

@ -88,7 +88,6 @@ struct GraphOutputInfo {
std::vector<tensor::TensorPtr> graph_output_tensors;
};
using KernelMapTensor = std::map<session::KernelWithIndex, BaseRef, session::KernelWithIndexCmp>;
class Executor;
class SessionBasic : public std::enable_shared_from_this<SessionBasic> {

View File

@ -1267,8 +1267,9 @@ void InitHccl() {
uint32_t device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
#if ENABLE_D
bool task_sink = true;
auto single_op = std::getenv(kGraphOpRun);
if (single_op && std::string(single_op) == "1") {
auto single_op = common::GetEnv(kGraphOpRun);
auto enable_mem_scheduler = common::GetEnv(kEnableMemScheduler);
if (single_op == "1" || enable_mem_scheduler == "1") {
task_sink = false;
}
auto mode = ms_context->get_param<int>(MS_CTX_EXECUTION_MODE);

View File

@ -1,6 +1,6 @@
file(GLOB_RECURSE DEVICE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "common/*.cc"
"kernel_info.cc" "executor/dynamic_kernel.cc" "executor/executor_callback.cc" "kernel_runtime.cc"
"memory_manager.cc" "kernel_runtime_manager.cc" "convert_tensor_utils.cc"
"memory_manager.cc" "kernel_runtime_manager.cc" "convert_tensor_utils.cc" "memory_scheduler.cc"
"bucket.cc" "launch_kernel.cc" "launch_mul.cc"
)

View File

@ -173,6 +173,10 @@ void AscendKernelRuntime::ClearGraphModelMap() {
void AscendKernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id) {
SetCurrentContext();
auto mem_scheduler = mem_scheduler_manager_.GetMemScheduler(graph_id);
if (mem_scheduler != nullptr) {
mem_scheduler->Clear();
}
MS_LOG(DEBUG) << "Clear graph:" << graph_id << " data dumper";
#ifndef ENABLE_SECURITY
if (auto dumper_iter = graph_data_dumper_.find(graph_id); dumper_iter != graph_data_dumper_.end()) {
@ -725,25 +729,6 @@ bool AscendKernelRuntime::Run(session::KernelGraph *const graph, bool is_task_si
return ret;
}
bool AscendKernelRuntime::LaunchKernel(const AnfNodePtr &kernel) {
auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
MS_EXCEPTION_IF_NULL(kernel_mod);
AddressPtrList kernel_inputs = kernel_mod->GetInputsAddr();
AddressPtrList kernel_workspaces = kernel_mod->GetWorkSpacesAddr();
AddressPtrList kernel_outputs = kernel_mod->GetOutputsAddr();
bool ret;
if (pynative_mode_profiling_flag_) {
auto ascend_kernel_mod = dynamic_cast<kernel::AscendKernelMod *>(kernel_mod);
MS_EXCEPTION_IF_NULL(ascend_kernel_mod);
auto stream = ascend_kernel_mod->GetStream();
ret = LaunchKernelWithPynativeProfiling(kernel_mod, kernel->fullname_with_scope(), kernel_inputs, kernel_workspaces,
kernel_outputs, stream);
} else {
ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_);
}
return ret;
}
void AscendKernelRuntime::SetKernelModStream(const std::vector<CNodePtr> &kernels,
std::vector<size_t> *last_stream_nodes) {
auto context_ptr = MsContext::GetInstance();

View File

@ -48,7 +48,6 @@ class AscendKernelRuntime : public KernelRuntime {
void ProcessBoundaryEvent(const std::vector<CNodePtr> &kernels,
std::vector<std::vector<std::function<void()>>> *kernel_run_events,
const std::vector<size_t> &last_stream_nodes);
bool LaunchKernel(const AnfNodePtr &kernel) override;
bool GenDynamicKernel(const session::KernelGraph *graph) override;
bool RunDynamicKernelAsync(const session::KernelGraph *graph) override;
bool LoadTask(const session::KernelGraph *graph);

View File

@ -128,6 +128,10 @@ void *AscendMemoryManager::MallocMemFromMemPool(size_t size) {
return AscendMemoryPool::GetInstance().AllocTensorMem(align_size);
}
void AscendMemoryManager::FreeMemFromMemPool(void *device_ptr) {
AscendMemoryPool::GetInstance().FreeTensorMem(device_ptr);
}
uint8_t *AscendMemoryManager::MallocStaticMem(size_t size, bool communication_mem, uint32_t graph_id) {
size_t align_size = 0;
if (communication_mem) {
@ -209,6 +213,47 @@ uint8_t *AscendMemoryManager::MallocCommunicationMemFromMemPool(size_t size) {
uint8_t *base_ptr = reinterpret_cast<uint8_t *>(AscendMemoryPool::GetInstance().AllocTensorMem(align_size));
return base_ptr + kMemAlignSize;
}
size_t AscendMemoryManager::GetAvailableMemSize() {
auto available_mem_size = AscendMemoryPool::GetInstance().free_mem_size() +
AscendMemoryPool::GetInstance().total_mem_statistics() -
AscendMemoryPool::GetInstance().used_mem_statistics();
return available_mem_size;
}
void AscendMemoryManager::SwapIn(void *host_ptr, void *device_ptr, size_t mem_size, void *stream) {
if (stream == nullptr) {
auto ret_rt_memcpy = rtMemcpy(device_ptr, mem_size, host_ptr, mem_size, RT_MEMCPY_HOST_TO_DEVICE);
if (ret_rt_memcpy != RT_ERROR_NONE) {
MS_EXCEPTION(DeviceProcessError) << "SwapIn rtMemcpy failed.";
}
} else {
auto ret_rt_memcpy = rtMemcpyAsync(device_ptr, mem_size, host_ptr, mem_size, RT_MEMCPY_HOST_TO_DEVICE, stream);
if (ret_rt_memcpy != RT_ERROR_NONE) {
MS_EXCEPTION(DeviceProcessError) << "SwapIn rtMemcpyAsync failed.";
}
if (rtStreamSynchronize(stream) != RT_ERROR_NONE) {
MS_LOG(ERROR) << "Call runtime rtStreamSynchronize error.";
}
}
}
void AscendMemoryManager::SwapOut(void *device_ptr, void *host_ptr, size_t mem_size, void *stream) {
if (stream == nullptr) {
auto ret_rt_memcpy = rtMemcpy(host_ptr, mem_size, device_ptr, mem_size, RT_MEMCPY_DEVICE_TO_HOST);
if (ret_rt_memcpy != RT_ERROR_NONE) {
MS_EXCEPTION(DeviceProcessError) << "SwapOut rtMemcpy failed.";
}
} else {
auto ret_rt_memcpy = rtMemcpyAsync(host_ptr, mem_size, device_ptr, mem_size, RT_MEMCPY_DEVICE_TO_HOST, stream);
if (ret_rt_memcpy != RT_ERROR_NONE) {
MS_EXCEPTION(DeviceProcessError) << "SwapOut rtMemcpyAsync failed.";
}
if (rtStreamSynchronize(stream) != RT_ERROR_NONE) {
MS_LOG(ERROR) << "Call runtime rtStreamSynchronize error.";
}
}
}
} // namespace ascend
} // namespace device
} // namespace mindspore

View File

@ -33,6 +33,7 @@ class AscendMemoryManager : public MemoryManager {
void ResetDynamicMemory() override;
void ClearGlobalIdleMem() override;
void *MallocMemFromMemPool(size_t size) override;
void FreeMemFromMemPool(void *device_ptr) override;
uint64_t GetDeviceMemSize();
void MallocSomasDynamicMem(const session::KernelGraph *graph) override;
uint8_t *MallocCommunicationMemFromMemPool(size_t size) override;
@ -40,6 +41,10 @@ class AscendMemoryManager : public MemoryManager {
return AscendMemoryPool::GetInstance().AllocContinuousTensorMem(total_size, size_list);
}
void SwapIn(void *host_ptr, void *device_ptr, size_t mem_size, void *stream) override;
void SwapOut(void *device_ptr, void *host_ptr, size_t mem_size, void *stream) override;
size_t GetAvailableMemSize() override;
protected:
uint8_t *MallocStaticMem(size_t size, bool communication_mem, uint32_t graph_id = kInvalidGraphId) override;
uint8_t *MallocDynamicMem(size_t size, bool communication_mem) override;

View File

@ -42,6 +42,9 @@ using mindspore::kernel::AddressPtr;
namespace mindspore {
namespace device {
constexpr float kMaxMemReuseFactor = 0.8;
constexpr float kMinMemReuseFactor = 0.5;
constexpr float kRetryFactor = 0.1;
namespace {
std::vector<AnfNodePtr> GetGraphInputs(const session::KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(graph);
@ -85,10 +88,16 @@ bool KernelRuntime::NodeOutputDeviceAddressExist(const AnfNodePtr &kernel, size_
void KernelRuntime::AssignMemory(session::KernelGraph *graph) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
MS_EXCEPTION_IF_NULL(mem_manager_);
mem_manager_->ResetDynamicMemory();
AssignStaticMemory(graph);
AssignDynamicMemory(graph);
auto enable_mem_scheduler = context_ptr->get_param<bool>(MS_CTX_ENABLE_MEM_SCHEDULER);
if (enable_mem_scheduler) {
AssignStaticMemoryValueNode(graph);
ResetNodeAddress(graph);
} else {
MS_EXCEPTION_IF_NULL(mem_manager_);
mem_manager_->ResetDynamicMemory();
AssignStaticMemory(graph);
AssignDynamicMemory(graph);
}
UpdateRefNodeOutputMem(graph);
}
@ -253,6 +262,47 @@ void KernelRuntime::RunOpMallocPre(const session::KernelGraph &graph,
}
}
void KernelRuntime::ResetNodeAddress(session::KernelGraph *kernel_graph) {
MS_EXCEPTION_IF_NULL(kernel_graph);
auto kernels = kernel_graph->execution_order();
for (auto &kernel : kernels) {
auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
MS_EXCEPTION_IF_NULL(kernel_mod);
size_t input_num = AnfAlgo::GetInputTensorNum(kernel);
for (size_t j = 0; j < input_num; ++j) {
auto input_index = AnfAlgo::GetRealInputIndex(kernel, j);
KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(kernel, input_index, true);
auto index = kernel_with_index.second;
auto &input_node = kernel_with_index.first;
if (NodeOutputDeviceAddressExist(input_node, index)) {
continue;
}
TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(input_node, index);
if (output_type_id == kTypeUnknown) {
MS_LOG(WARNING) << "It is not suggested to use a lonely weight parameter as the output of graph";
continue;
}
auto tensor_size = AnfAlgo::GetOutputTensorMemSize(input_node, index);
auto device_address = CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(input_node, index),
output_type_id, {input_node, index});
AnfAlgo::SetOutputAddr(device_address, index, input_node.get());
}
auto output_sizes = kernel_mod->GetOutputSizeList();
for (size_t i = 0; i < output_sizes.size(); ++i) {
auto output_format = AnfAlgo::GetOutputFormat(kernel, i);
auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i);
AnfAlgo::SetOutputAddr(CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type), i,
kernel.get());
}
auto workspace_sizes = kernel_mod->GetWorkspaceSizeList();
for (size_t i = 0; i < workspace_sizes.size(); ++i) {
AnfAlgo::SetWorkspaceAddr(CreateDeviceAddress(nullptr, workspace_sizes[i], kOpFormat_DEFAULT, kNumberTypeFloat32),
i, kernel.get());
}
}
}
void KernelRuntime::RunOpAssignMemory(const std::vector<tensor::TensorPtr> &input_tensors, session::KernelGraph *graph,
const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node) {
MS_EXCEPTION_IF_NULL(graph);
@ -1125,7 +1175,8 @@ void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod
}
}
void KernelRuntime::GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList *kernel_inputs) {
void KernelRuntime::GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList *kernel_inputs,
const std::shared_ptr<MemScheduler> &mem_scheduler) {
MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(kernel_inputs);
if (cnode->inputs().size() != 2) {
@ -1144,8 +1195,12 @@ void KernelRuntime::GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList
auto device_address = AnfAlgo::GetOutputAddr(pre_node, index);
kernel::AddressPtr input = std::make_shared<kernel::Address>();
MS_EXCEPTION_IF_NULL(input);
input->addr = device_address->ptr_;
MS_EXCEPTION_IF_NULL(input->addr);
if (mem_scheduler != nullptr) {
GetOrMallocAddress(mem_scheduler, device_address, input);
} else {
input->addr = device_address->ptr_;
MS_EXCEPTION_IF_NULL(input->addr);
}
input->size = device_address->size_;
kernel_inputs->emplace_back(input);
}
@ -1162,8 +1217,12 @@ void KernelRuntime::GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList
auto device_address = AnfAlgo::GetWorkspaceAddr(pre_node, index);
kernel::AddressPtr workspace = std::make_shared<kernel::Address>();
MS_EXCEPTION_IF_NULL(workspace);
workspace->addr = device_address->ptr_;
MS_EXCEPTION_IF_NULL(workspace->addr);
if (mem_scheduler != nullptr) {
GetOrMallocAddress(mem_scheduler, device_address, workspace);
} else {
workspace->addr = device_address->ptr_;
MS_EXCEPTION_IF_NULL(workspace->addr);
}
workspace->size = device_address->size_;
kernel_inputs->emplace_back(workspace);
}
@ -1220,36 +1279,197 @@ void KernelRuntime::DebugStreamSync(const CNodePtr &kernel) {
}
}
bool KernelRuntime::LaunchKernel(const AnfNodePtr &kernel) {
void KernelRuntime::GetOrMallocAddress(const std::shared_ptr<MemScheduler> &mem_scheduler,
const DeviceAddress *device_address, const kernel::AddressPtr &kernel_addr) {
if (device_address->ptr_ != nullptr) {
kernel_addr->addr = device_address->ptr_;
} else {
kernel_addr->addr = mem_scheduler->GetOrMalloc(device_address, device_address->size_);
if (mem_scheduler->IsHighPriorityMem(device_address)) {
device_address->ptr_ = kernel_addr->addr;
}
}
}
void KernelRuntime::AssignKernelAddress(const std::shared_ptr<MemScheduler> &mem_scheduler, const AnfNodePtr &kernel,
AddressPtrList *kernel_inputs, AddressPtrList *kernel_workspaces,
AddressPtrList *kernel_outputs) {
MS_EXCEPTION_IF_NULL(kernel);
MS_EXCEPTION_IF_NULL(kernel_inputs);
MS_EXCEPTION_IF_NULL(kernel_workspaces);
MS_EXCEPTION_IF_NULL(kernel_outputs);
auto cnode = kernel->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (AnfAlgo::GetCNodeName(cnode) == kAtomicAddrCleanOpName) {
return GenAddrCleanLaunchArgs(cnode, kernel_inputs, mem_scheduler);
}
auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
MS_EXCEPTION_IF_NULL(kernel_mod);
size_t input_num = AnfAlgo::GetInputTensorNum(kernel);
for (size_t j = 0; j < input_num; ++j) {
auto real_input = AnfAlgo::GetRealInputIndex(kernel, j);
auto kernel_with_index = AnfAlgo::GetPrevNodeOutput(kernel, real_input, true);
auto index = kernel_with_index.second;
auto &input_node = kernel_with_index.first;
auto device_address = AnfAlgo::GetOutputAddr(input_node, index, true);
MS_EXCEPTION_IF_NULL(device_address);
kernel::AddressPtr input = std::make_shared<kernel::Address>();
GetOrMallocAddress(mem_scheduler, device_address, input);
input->size = device_address->size_;
kernel_inputs->emplace_back(input);
}
for (size_t j = 0; j < kernel_mod->GetOutputSizeList().size(); ++j) {
auto device_address = AnfAlgo::GetOutputAddr(kernel, j, true);
kernel::AddressPtr output = std::make_shared<kernel::Address>();
GetOrMallocAddress(mem_scheduler, device_address, output);
output->size = device_address->size_;
kernel_outputs->emplace_back(output);
}
for (size_t i = 0; i < kernel_mod->GetWorkspaceSizeList().size(); ++i) {
auto device_address = AnfAlgo::GetWorkspaceAddr(kernel, i);
kernel::AddressPtr workspace = std::make_shared<kernel::Address>();
GetOrMallocAddress(mem_scheduler, device_address, workspace);
workspace->size = device_address->size_;
kernel_workspaces->emplace_back(workspace);
}
}
void KernelRuntime::SyncNodeOutputTensors(const std::shared_ptr<MemScheduler> &mem_scheduler,
const session::KernelGraph *graph, const AnfNodePtr &kernel, bool mock) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(mem_scheduler);
MS_EXCEPTION_IF_NULL(kernel);
auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
MS_EXCEPTION_IF_NULL(kernel_mod);
for (size_t j = 0; j < kernel_mod->GetOutputSizeList().size(); ++j) {
auto tensor = graph->GetNodeOutputTensor(std::make_pair(kernel, j));
auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, j, true);
if (mock) {
if (graph->IsInternalOutput(kernel, j) && device_address != nullptr) {
mem_scheduler->SetMemPriority(device_address.get(), kMemPriorityHigh);
}
continue;
}
if (tensor != nullptr) {
if (device_address == nullptr) {
tensor->data_sync(false);
tensor->set_device_address(nullptr);
tensor->set_sync_status(kNeedSyncHostToDevice);
continue;
}
SyncStream();
auto origin_ptr = device_address->ptr_;
if (origin_ptr == nullptr) {
device_address->ptr_ = mem_scheduler->GetOrMalloc(device_address.get(), device_address->size_);
}
tensor->set_device_address(device_address);
tensor->data_sync(false);
tensor->set_device_address(nullptr);
if (origin_ptr == nullptr) {
device_address->ptr_ = nullptr;
}
tensor->set_sync_status(kNeedSyncHostToDevice);
}
}
}
void KernelRuntime::InitGraphInputTensors(const std::shared_ptr<MemScheduler> &mem_scheduler,
const session::KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(mem_scheduler);
auto &input_nodes = graph->input_nodes();
auto &input_tensors = graph->input_tensors();
if (input_tensors.size() != input_nodes.size()) {
MS_LOG_EXCEPTION << "Invalid input tensor size:" << input_tensors.size() << " vs node size:" << input_nodes.size();
}
for (size_t i = 0; i < input_tensors.size(); ++i) {
auto tensor = input_tensors[i];
MS_EXCEPTION_IF_NULL(tensor);
auto input_node = input_nodes[i];
if (!input_node->isa<Parameter>()) {
continue;
}
auto input_param = input_node->cast<ParameterPtr>();
if (AnfAlgo::OutputAddrExist(input_node, 0)) {
auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0);
MS_EXCEPTION_IF_NULL(tensor);
MemPriority priority = kMemPriorityHigh;
auto tensor_address = tensor->device_address();
if (tensor_address != nullptr && tensor_address != device_address) {
tensor->data_sync(false);
priority = kMemPriorityLow;
}
mem_scheduler->Init(device_address.get(), tensor->data_c(), tensor->data().nbytes(), priority);
}
}
}
bool KernelRuntime::LaunchKernel(const session::KernelGraph *graph, const AnfNodePtr &kernel,
const std::shared_ptr<MemScheduler> &mem_scheduler, bool mock) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(kernel);
auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
MS_EXCEPTION_IF_NULL(kernel_mod);
AddressPtrList kernel_inputs;
AddressPtrList kernel_workspaces;
AddressPtrList kernel_outputs;
GenLaunchArgs(*kernel_mod, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
bool ret;
if (AnfAlgo::IsCommunicationOp(kernel)) {
if (pynative_mode_profiling_flag_) {
ret = LaunchKernelWithPynativeProfiling(kernel_mod, kernel->fullname_with_scope(), kernel_inputs,
kernel_workspaces, kernel_outputs, communication_stream_);
auto stream = kernel_mod->GetStream();
if (stream == nullptr) {
if (AnfAlgo::IsCommunicationOp(kernel)) {
stream = communication_stream_;
} else {
ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, communication_stream_);
stream = stream_;
}
}
bool ret = true;
if (mem_scheduler != nullptr) {
ret = mem_scheduler->PreCompute(stream);
if (!ret) {
return ret;
}
AssignKernelAddress(mem_scheduler, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
} else if (!kernel_mod->GetInputsAddr().empty() || !kernel_mod->GetOutputsAddr().empty()) {
kernel_inputs = kernel_mod->GetInputsAddr();
kernel_outputs = kernel_mod->GetOutputsAddr();
kernel_workspaces = kernel_mod->GetWorkSpacesAddr();
} else {
GenLaunchArgs(*kernel_mod, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
}
if (!mock) {
if (pynative_mode_profiling_flag_) {
ret = LaunchKernelWithPynativeProfiling(kernel_mod, kernel->fullname_with_scope(), kernel_inputs,
kernel_workspaces, kernel_outputs, stream_);
kernel_workspaces, kernel_outputs, stream);
} else {
ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_);
ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream);
}
}
if (mem_scheduler != nullptr) {
SyncNodeOutputTensors(mem_scheduler, graph, kernel, mock);
ret = mem_scheduler->PostCompute(stream);
if (!ret) {
return ret;
}
}
return ret;
}
bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph) {
const auto &kernels = graph.execution_order();
bool KernelRuntime::LaunchKernelMod(const session::KernelGraph *graph, bool mock) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
std::shared_ptr<MemScheduler> mem_scheduler = nullptr;
auto enable_mem_scheduler = context_ptr->get_param<bool>(MS_CTX_ENABLE_MEM_SCHEDULER);
if (enable_mem_scheduler) {
mem_scheduler = mem_scheduler_manager_.GetOrCreateMemScheduler(graph->graph_id());
MS_EXCEPTION_IF_NULL(mem_scheduler);
mem_scheduler->SetMemHandler(mem_manager_);
mem_scheduler->RecordMemUsage();
InitGraphInputTensors(mem_scheduler, graph);
}
const auto &kernels = graph->execution_order();
std::vector<DynamicKernelPtr> dynamic_kernel_list;
auto iter = graph_dynamic_kernel_map_.find(graph.graph_id());
auto iter = graph_dynamic_kernel_map_.find(graph->graph_id());
if (iter != graph_dynamic_kernel_map_.end()) {
dynamic_kernel_list = iter->second;
}
@ -1259,7 +1479,7 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph) {
}
std::vector<std::vector<std::function<void()>>> kernel_pre_run_events;
std::vector<std::vector<std::function<void()>>> kernel_post_run_events;
auto events_iter = graph_kernel_events_map_.find(graph.graph_id());
auto events_iter = graph_kernel_events_map_.find(graph->graph_id());
if (events_iter != graph_kernel_events_map_.end()) {
kernel_pre_run_events = events_iter->second.first;
kernel_post_run_events = events_iter->second.second;
@ -1271,7 +1491,6 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph) {
dynamic_kernel_list[i]->InferShape();
dynamic_kernel_list[i]->UpdateArgs();
dynamic_kernel_list[i]->Execute();
if (!SyncStream()) {
MS_LOG(ERROR) << "SyncStream failed";
return false;
@ -1292,7 +1511,7 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph) {
}
continue;
}
auto ret = LaunchKernel(kernel);
auto ret = LaunchKernel(graph, kernel, mem_scheduler, mock);
if (!ret) {
MS_LOG(ERROR) << "Launch kernel failed.";
return false;
@ -1302,16 +1521,42 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph) {
}
LaunchKernelEvent(kernel_post_run_events, i);
}
if (mem_scheduler != nullptr) {
mem_scheduler->OptMemUsage();
}
return true;
}
void KernelRuntime::UseMemSchedulerIfNeeded(const session::KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(graph);
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
auto enable_mem_scheduler = context_ptr->get_param<bool>(MS_CTX_ENABLE_MEM_SCHEDULER);
if (enable_mem_scheduler) {
auto mem_scheduler = mem_scheduler_manager_.GetOrCreateMemScheduler(graph->graph_id());
if (mem_scheduler->need_record_event()) {
(void)LaunchKernelMod(graph, true);
}
float mem_used_factor = kMaxMemReuseFactor;
while (!mem_scheduler->optimized() && mem_used_factor >= kMinMemReuseFactor) {
mem_scheduler->SetMemUsedFactor(mem_used_factor);
bool ret = LaunchKernelMod(graph, true);
if (ret) {
mem_scheduler->SetOptimized(true);
} else {
mem_used_factor -= kRetryFactor;
}
}
}
}
bool KernelRuntime::LaunchKernels(const session::KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(graph);
if (!LaunchKernelMod(*graph)) {
UseMemSchedulerIfNeeded(graph);
if (!LaunchKernelMod(graph)) {
MS_LOG(ERROR) << "LaunchKernelMod failed!";
return false;
}
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {

View File

@ -33,6 +33,7 @@
#include "backend/kernel_compiler/kernel.h"
#include "utils/ms_context.h"
#include "runtime/device/memory_manager.h"
#include "runtime/device/memory_scheduler.h"
#include "runtime/device/executor/dynamic_kernel.h"
#include "ir/device_event.h"
@ -69,7 +70,6 @@ class KernelRuntime {
virtual bool GenDynamicKernel(const session::KernelGraph *graph) = 0;
virtual bool RunDynamicKernelAsync(const session::KernelGraph *graph) = 0;
bool LaunchKernels(const session::KernelGraph *graph);
virtual bool LaunchKernel(const AnfNodePtr &kernel);
virtual void AssignStaticMemoryInput(const session::KernelGraph *graph);
virtual void AssignStaticMemoryValueNode(session::KernelGraph *graph);
virtual void ClearGraphRuntimeResource(uint32_t graph_id);
@ -141,11 +141,24 @@ class KernelRuntime {
virtual void KernelLaunchProfiling(const std::string &kernel_name) {}
private:
void UseMemSchedulerIfNeeded(const session::KernelGraph *graph);
bool LaunchKernel(const session::KernelGraph *graph, const AnfNodePtr &kernel,
const std::shared_ptr<MemScheduler> &mem_scheduler, bool mock = false);
void ResetNodeAddress(session::KernelGraph *graph);
void AssignKernelAddress(const std::shared_ptr<MemScheduler> &mem_scheduler, const AnfNodePtr &kernel,
AddressPtrList *kernel_inputs, AddressPtrList *kernel_workspaces,
AddressPtrList *kernel_outputs);
static void GetOrMallocAddress(const std::shared_ptr<MemScheduler> &mem_scheduler,
const DeviceAddress *device_address, const kernel::AddressPtr &kernel_addr);
void InitGraphInputTensors(const std::shared_ptr<MemScheduler> &mem_scheduler, const session::KernelGraph *graph);
void SyncNodeOutputTensors(const std::shared_ptr<MemScheduler> &mem_scheduler, const session::KernelGraph *graph,
const AnfNodePtr &kernel, bool mock);
void AssignStaticMemoryOutput(const session::KernelGraph *graph);
bool LaunchKernelMod(const session::KernelGraph &graph);
bool LaunchKernelMod(const session::KernelGraph *graph, bool mock = false);
void LaunchKernelEvent(const std::vector<std::vector<std::function<void()>>> &run_events, size_t index) const;
void DebugStreamSync(const CNodePtr &kernel);
static void GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList *kernel_inputs);
static void GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList *kernel_inputs,
const std::shared_ptr<MemScheduler> &mem_schedule = nullptr);
void RunOpAssignInputMemory(const std::vector<tensor::TensorPtr> &input_tensors, const session::KernelGraph *graph);
void RunOpAssignOutputMemory(const AnfNodePtr &kernel,
const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node = {});
@ -179,6 +192,7 @@ class KernelRuntime {
std::map<uint32_t,
std::pair<std::vector<std::vector<std::function<void()>>>, std::vector<std::vector<std::function<void()>>>>>
graph_kernel_events_map_;
MemSchedulerManager mem_scheduler_manager_;
};
using KernelRuntimePtr = std::shared_ptr<KernelRuntime>;
} // namespace device

View File

@ -19,8 +19,11 @@
#include <memory>
#include <utility>
#include <vector>
#include <map>
#include <queue>
#include "backend/optimizer/mem_reuse/mem_reuse.h"
#include "backend/optimizer/somas/somas.h"
#include "runtime/device/memory_scheduler.h"
namespace mindspore {
namespace device {
enum MemType { kStaticMem, kDynamicMem, kSomasReuseDynamicMem };
@ -28,7 +31,7 @@ const int kGetAllOuts = -1;
const uint64_t kMemAlignSize = 512;
using SomasPtr = mindspore::somas::SomasPtr;
class MemoryManager {
class MemoryManager : public MemHandler {
public:
MemoryManager() = default;
virtual ~MemoryManager() = default;
@ -60,6 +63,45 @@ class MemoryManager {
static size_t GetCommonAlignSize(size_t input_size);
static size_t GetCommunicationAlignSize(size_t input_size);
// swap manager interface
void *MallocDevice(size_t mem_size) override { return MallocMemFromMemPool(mem_size); }
void FreeDevice(void *ptr) override {
MS_EXCEPTION_IF_NULL(ptr);
FreeMemFromMemPool(ptr);
}
void *MallocHost(size_t mem_size) override {
auto &mem_que = cached_host_mem_[mem_size];
if (!mem_que.empty()) {
auto ret = mem_que.front();
mem_que.pop();
return ret;
}
auto block = std::make_shared<std::vector<uint8_t>>();
try {
block->resize(mem_size, 0);
auto ptr = block->data();
host_mem_block_map_[ptr] = block;
return ptr;
} catch (const std::exception &e) {
MS_LOG(EXCEPTION) << "Malloc memory failed: size " << mem_size;
}
}
void FreeHost(void *ptr) override {
MS_EXCEPTION_IF_NULL(ptr);
auto iter = host_mem_block_map_.find(ptr);
if (iter == host_mem_block_map_.end()) {
MS_LOG(ERROR) << "Free ptr not be created from manager!";
}
auto mem_size = iter->second->size();
cached_host_mem_[mem_size].emplace(iter->first);
}
void SwapIn(void *host_ptr, void *device_ptr, size_t mem_size, void *stream) override {}
void SwapOut(void *device_ptr, void *host_ptr, size_t mem_size, void *stream) override {}
size_t GetAvailableMemSize() override {
MS_LOG(ERROR) << "Return default 0 mem size!";
return 0;
}
protected:
virtual uint8_t *MallocStaticMem(size_t size, bool communication_mem, uint32_t graph_id = kInvalidGraphId) = 0;
virtual uint8_t *MallocDynamicMem(size_t size, bool communication_mem);
@ -70,6 +112,8 @@ class MemoryManager {
size_t total_static_size_ = 0;
size_t total_dynamic_size_ = 0;
SomasPtr somas_reuse_util_ptr_{nullptr};
std::map<size_t, std::queue<void *>> cached_host_mem_;
std::map<void *, std::shared_ptr<std::vector<uint8_t>>> host_mem_block_map_;
};
} // namespace device
} // namespace mindspore

View File

@ -0,0 +1,377 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/device/memory_scheduler.h"
#include <map>
#include <set>
#include <memory>
#include <utility>
#include <algorithm>
#include "utils/log_adapter.h"
namespace mindspore {
namespace device {
void MemScheduler::Clear() {
if (mem_handler_ == nullptr) {
return;
}
for (auto &item : high_priority_device_ptr_) {
mem_handler_->FreeDevice(item.second);
}
high_priority_device_ptr_.clear();
}
bool MemScheduler::IsHighPriorityMem(const void *key) {
auto iter = mem_priority_.find(key);
if (iter != mem_priority_.end()) {
return iter->second == kMemPriorityHigh;
}
return false;
}
void MemScheduler::SetMemPriority(const void *key, MemPriority priority) { mem_priority_[key] = priority; }
void MemScheduler::Record(const void *key, const EventType &event_type, size_t mem_size) {
if (key == nullptr) {
return;
}
auto event = std::make_shared<Event>(event_type, compute_index_);
event->mem_size = mem_size;
event->key = key;
mem_events_[key].emplace_back(event);
}
void MemScheduler::Init(const void *key, void *host_ptr, size_t mem_size, MemPriority priority) {
if (need_record_event_) {
mem_priority_[key] = priority;
Record(key, kInit, mem_size);
} else {
init_host_ptr_[key] = host_ptr;
}
}
void *MemScheduler::GetOrMalloc(const void *key, size_t mem_size, MemPriority priority) {
if (need_record_event_) {
if (mem_priority_.find(key) == mem_priority_.end()) {
mem_priority_[key] = priority;
Record(key, kMalloc, mem_size);
} else {
Record(key, kGet, mem_size);
}
return nullptr;
}
auto iter = mem_result_.find(key);
if (iter != mem_result_.end()) {
auto ptr = iter->second;
MS_EXCEPTION_IF_NULL(ptr);
return ptr;
} else {
MS_LOG_EXCEPTION << "Mem extender get nullptr result!";
}
}
bool MemScheduler::PreCompute(void *stream) {
if (need_record_event_) {
return true;
}
MS_EXCEPTION_IF_NULL(mem_handler_);
if (pre_compute_events_.size() <= compute_index_) {
MS_LOG_EXCEPTION << "Index out of pre event range, index:" << compute_index_
<< ", event size:" << pre_compute_events_.size();
}
auto &events = pre_compute_events_[compute_index_];
for (auto &event : events) {
MS_EXCEPTION_IF_NULL(event);
MS_LOG(DEBUG) << "Pre compute " << compute_index_ << ": " << event->key << " v " << event->type;
if (event->type == kInit) {
auto host_ptr = init_host_ptr_[event->key];
MS_EXCEPTION_IF_NULL(host_ptr);
auto priority = mem_priority_[event->key];
auto iter = high_priority_device_ptr_.find(event->key);
if (priority != kMemPriorityLow && iter != high_priority_device_ptr_.end()) {
MS_EXCEPTION_IF_NULL(iter->second);
mem_result_[event->key] = iter->second;
if (priority == kMemPriorityMedium) {
mem_handler_->SwapIn(host_ptr, iter->second, event->mem_size, stream);
}
continue;
}
auto device_ptr = mem_handler_->MallocDevice(event->mem_size);
if (device_ptr == nullptr) {
return false;
}
if (priority != kMemPriorityLow) {
high_priority_device_ptr_[event->key] = device_ptr;
}
mem_handler_->SwapIn(host_ptr, device_ptr, event->mem_size, stream);
mem_result_[event->key] = device_ptr;
} else if (event->type == kMalloc) {
auto device_ptr = mem_handler_->MallocDevice(event->mem_size);
if (device_ptr == nullptr) {
return false;
}
mem_result_[event->key] = device_ptr;
} else if (event->type == kSwapIn) {
bool from_init = true;
auto host_ptr = init_host_ptr_[event->key];
if (host_ptr == nullptr) {
host_ptr = swap_host_ptr_[event->key];
from_init = false;
}
auto device_ptr = mem_handler_->MallocDevice(event->mem_size);
if (device_ptr == nullptr) {
return false;
}
MS_EXCEPTION_IF_NULL(host_ptr);
mem_handler_->SwapIn(host_ptr, device_ptr, event->mem_size, stream);
mem_result_[event->key] = device_ptr;
if (!from_init) {
mem_handler_->FreeHost(host_ptr);
swap_host_ptr_.erase(event->key);
}
}
}
return true;
}
bool MemScheduler::PostCompute(void *stream) {
if (need_record_event_) {
++compute_index_;
return true;
}
MS_EXCEPTION_IF_NULL(mem_handler_);
if (post_compute_events_.size() <= compute_index_) {
MS_LOG_EXCEPTION << "Index out of post event range, index:" << compute_index_
<< ", event size:" << post_compute_events_.size();
}
auto &events = post_compute_events_[compute_index_];
for (auto &event : events) {
MS_EXCEPTION_IF_NULL(event);
MS_LOG(DEBUG) << "Post compute " << compute_index_ << ": " << event->key << " v " << event->type;
if (event->type == kFree) {
auto ptr = mem_result_[event->key];
if (ptr == nullptr) {
return false;
}
mem_handler_->FreeDevice(ptr);
(void)mem_result_.erase(event->key);
} else if (event->type == kSwapOut) {
auto device_ptr = mem_result_[event->key];
if (device_ptr == nullptr) {
return false;
}
auto host_ptr = init_host_ptr_[event->key];
if (host_ptr == nullptr) {
host_ptr = mem_handler_->MallocHost(event->mem_size);
swap_host_ptr_[event->key] = host_ptr;
}
MS_EXCEPTION_IF_NULL(host_ptr);
mem_handler_->SwapOut(device_ptr, host_ptr, event->mem_size, stream);
mem_handler_->FreeDevice(device_ptr);
mem_result_.erase(device_ptr);
}
}
++compute_index_;
return true;
}
void MemScheduler::OptMemUsage() {
need_record_event_ = false;
if (optimized_) {
return;
}
CountMemUsage();
CheckMemSize();
if (need_swap_) {
GenEventSpan();
GenNoSwapEventSet();
}
GenEvents();
}
void MemScheduler::CountMemUsage() {
if (!min_mem_used_.empty()) {
return;
}
min_mem_used_.resize(compute_index_, 0);
std::vector<size_t> total_mem_used(compute_index_, 0);
for (auto &item : mem_events_) {
auto &mem_events = item.second;
if (mem_events.empty()) {
continue;
}
auto first_event = mem_events[0];
MS_EXCEPTION_IF_NULL(first_event);
size_t i = 0;
if (first_event->type == kInit && mem_events.size() > 1) {
first_event = mem_events[1];
i = 1;
}
auto last_event = mem_events[mem_events.size() - 1];
for (size_t start_index = first_event->index; start_index <= last_event->index; ++start_index) {
if (start_index < compute_index_) {
total_mem_used[start_index] += first_event->mem_size;
} else {
MS_LOG(ERROR) << "Error mem event index " << start_index;
}
}
for (; i < mem_events.size(); ++i) {
auto &event = mem_events[i];
MS_EXCEPTION_IF_NULL(event);
if (event->index < compute_index_) {
min_mem_used_[event->index] += first_event->mem_size;
} else {
MS_LOG(ERROR) << "Error mem event index " << event->index;
}
}
}
min_mem_needed_ = *(std::max_element(min_mem_used_.begin(), min_mem_used_.end()));
mem_used_without_swap_ = *(std::max_element(total_mem_used.begin(), total_mem_used.end()));
}
void MemScheduler::CheckMemSize() {
MS_EXCEPTION_IF_NULL(mem_handler_);
auto available_mem_size = mem_handler_->GetAvailableMemSize();
if (available_mem_size < min_mem_needed_) {
MS_LOG(EXCEPTION) << "Out of memory, as available mem size is " << available_mem_size
<< " while graph needs at least " << min_mem_needed_;
}
if (mem_used_without_swap_ > available_mem_size) {
need_swap_ = true;
}
MS_LOG(INFO) << "Available mem size: " << available_mem_size << ", graph needs mem size:" << mem_used_without_swap_
<< "without swap, and needs at least " << min_mem_needed_ << " with swap.";
}
void MemScheduler::GenEventSpan() {
if (!event_span_.empty()) {
return;
}
for (auto &item : mem_events_) {
auto &mem_events = item.second;
if (mem_events.empty()) {
continue;
}
auto first_event = mem_events[0];
MS_EXCEPTION_IF_NULL(first_event);
size_t i = 0;
if (first_event->type == kInit && mem_events.size() > 1) {
first_event = mem_events[1];
i = 1;
}
size_t last_index = first_event->index;
for (; i < mem_events.size(); ++i) {
auto &event = mem_events[i];
MS_EXCEPTION_IF_NULL(event);
auto span = event->index - last_index;
if (span > 1) {
(void)event_span_.insert(std::pair<size_t, std::shared_ptr<Event>>(span, event));
}
last_index = event->index;
}
}
}
void MemScheduler::GenNoSwapEventSet() {
MS_EXCEPTION_IF_NULL(mem_handler_);
auto available_mem_size = mem_handler_->GetAvailableMemSize();
auto threshold = available_mem_size * mem_used_factor_;
no_swap_events_.clear();
std::vector<size_t> cur_mem_used(min_mem_used_.begin(), min_mem_used_.end());
for (auto iter = event_span_.begin(); iter != event_span_.end(); ++iter) {
auto span = iter->first;
auto &event = iter->second;
auto start_index = event->index - span + 1;
bool revert = false;
for (size_t i = start_index; i < event->index; ++i) {
cur_mem_used[i] += event->mem_size;
if (cur_mem_used[i] > threshold) {
revert = true;
}
}
if (revert) {
for (size_t i = start_index; i < event->index; ++i) {
cur_mem_used[i] -= event->mem_size;
}
} else {
no_swap_events_.emplace(event);
}
}
}
void MemScheduler::GenEvents() {
pre_compute_events_.resize(compute_index_);
post_compute_events_.resize(compute_index_);
for (auto &item : mem_events_) {
auto &mem_events = item.second;
if (mem_events.empty()) {
continue;
}
auto first_event = mem_events[0];
MS_EXCEPTION_IF_NULL(first_event);
if (first_event->type == kInit) {
if (mem_events.size() > 1) {
auto &second_event = mem_events[1];
MS_EXCEPTION_IF_NULL(second_event);
first_event->index = second_event->index;
} else {
continue;
}
}
if ((first_event->type == kInit || first_event->type == kMalloc) &&
first_event->index < pre_compute_events_.size()) {
pre_compute_events_[first_event->index].emplace_back(first_event);
} else {
MS_LOG_EXCEPTION << "First event should be init or malloc!";
}
MemPriority priority = kMemPriorityLow;
auto iter = mem_priority_.find(first_event->key);
if (iter != mem_priority_.end()) {
priority = iter->second;
}
size_t pre_index = first_event->index;
for (size_t i = 1; i < mem_events.size(); ++i) {
auto &event = mem_events[i];
MS_EXCEPTION_IF_NULL(event);
if (need_swap_ && event->index - pre_index > 1 && priority == kMemPriorityLow &&
no_swap_events_.find(event) == no_swap_events_.end()) {
auto swap_out_event = std::make_shared<Event>(kSwapOut, pre_index);
swap_out_event->key = item.first;
swap_out_event->mem_size = first_event->mem_size;
post_compute_events_[pre_index].emplace_back(swap_out_event);
auto swap_in_event = std::make_shared<Event>(kSwapIn, event->index);
swap_in_event->key = item.first;
swap_in_event->mem_size = first_event->mem_size;
pre_compute_events_[event->index].emplace_back(swap_in_event);
}
if (event->index < pre_compute_events_.size()) {
pre_compute_events_[event->index].emplace_back(event);
}
pre_index = event->index;
}
if (priority != kMemPriorityLow) {
continue;
}
auto &last_event = mem_events[mem_events.size() - 1];
MS_EXCEPTION_IF_NULL(last_event);
auto free_event = std::make_shared<Event>(kFree, last_event->index);
free_event->key = item.first;
if (last_event->index < post_compute_events_.size()) {
post_compute_events_[last_event->index].emplace_back(free_event);
}
}
}
} // namespace device
} // namespace mindspore

View File

@ -0,0 +1,143 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_MEMORY_SCHEDULER_H_
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_MEMORY_SCHEDULER_H_
#include <vector>
#include <map>
#include <set>
#include <memory>
#include <utility>
namespace mindspore {
namespace device {
class MemHandler {
public:
virtual size_t GetAvailableMemSize() = 0;
virtual void *MallocDevice(size_t mem_size) = 0;
virtual void FreeDevice(void *ptr) = 0;
virtual void *MallocHost(size_t mem_size) = 0;
virtual void FreeHost(void *ptr) = 0;
virtual void SwapIn(void *host_ptr, void *device_ptr, size_t mem_size, void *stream) = 0;
virtual void SwapOut(void *device_ptr, void *host_ptr, size_t mem_size, void *stream) = 0;
};
enum MemPriority { kMemPriorityLow, kMemPriorityMedium, kMemPriorityHigh };
class MemScheduler {
enum EventType { kInit, kMalloc, kGet, kFree, kSwapIn, kSwapOut };
struct Event {
Event(const EventType &in_type, size_t in_index) {
type = in_type;
index = in_index;
}
EventType type;
size_t index{0};
size_t mem_size{0};
const void *key{nullptr};
};
public:
MemScheduler() = default;
~MemScheduler() = default;
bool need_record_event() const { return need_record_event_; }
bool optimized() const { return optimized_; }
void SetOptimized(bool flag) { optimized_ = flag; }
void SetMemHandler(const std::shared_ptr<MemHandler> &handler) { mem_handler_ = handler; }
void Init(const void *key, void *host_ptr, size_t mem_size, MemPriority priority = kMemPriorityLow);
void *GetOrMalloc(const void *key, size_t mem_size, MemPriority priority = kMemPriorityLow);
void RecordMemUsage() { compute_index_ = 0; }
bool PreCompute(void *stream);
bool PostCompute(void *stream);
void OptMemUsage();
void Clear();
bool IsHighPriorityMem(const void *key);
void SetMemPriority(const void *key, MemPriority priority);
void SetMemUsedFactor(float factor) { mem_used_factor_ = factor; }
void SetNeedSwap(bool flag) { need_swap_ = flag; }
private:
void Record(const void *key, const EventType &event_type, size_t mem_size = 0);
void GenEvents();
void CheckMemSize();
void CountMemUsage();
void GenEventSpan();
void GenNoSwapEventSet();
std::map<const void *, MemPriority> mem_priority_;
std::map<const void *, std::vector<std::shared_ptr<Event>>> mem_events_;
std::vector<std::vector<std::shared_ptr<Event>>> pre_compute_events_;
std::vector<std::vector<std::shared_ptr<Event>>> post_compute_events_;
std::map<const void *, void *> mem_result_;
std::map<const void *, void *> init_host_ptr_;
std::map<const void *, void *> swap_host_ptr_;
std::map<const void *, void *> high_priority_device_ptr_;
size_t compute_index_{0};
bool need_record_event_{true};
bool optimized_{false};
std::shared_ptr<MemHandler> mem_handler_{nullptr};
bool need_swap_{false};
std::multimap<size_t, std::shared_ptr<Event>> event_span_;
std::set<std::shared_ptr<Event>> no_swap_events_;
std::vector<size_t> min_mem_used_;
size_t mem_used_without_swap_{0};
size_t min_mem_needed_{0};
float mem_used_factor_{0.9};
};
class MemSchedulerManager {
public:
MemSchedulerManager() = default;
~MemSchedulerManager() = default;
std::shared_ptr<MemScheduler> GetOrCreateMemScheduler(uint64_t uid) {
auto scheduler = GetMemScheduler(uid);
if (scheduler == nullptr) {
scheduler = std::make_shared<MemScheduler>();
graph_mem_scheduler_map_[uid] = scheduler;
}
return scheduler;
}
std::shared_ptr<MemScheduler> GetMemScheduler(uint64_t uid) {
auto iter = graph_mem_scheduler_map_.find(uid);
if (iter != graph_mem_scheduler_map_.end()) {
return iter->second;
}
return nullptr;
}
private:
std::map<uint64_t, std::shared_ptr<MemScheduler>> graph_mem_scheduler_map_;
};
} // namespace device
} // namespace mindspore
#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_MEMORY_SCHEDULER_H_

View File

@ -481,6 +481,7 @@ constexpr auto kValueTargetOther = "target_other";
// env key
constexpr auto kGraphOpRun = "GRAPH_OP_RUN";
constexpr auto kEnableMemScheduler = "ENABLE_MEM_SCHEDULER";
// some size
const size_t kShape4dDims = 4;

View File

@ -602,8 +602,13 @@ BackendPtr CreateBackend() {
backend->set_is_multi_graph_sink(false);
context_ptr->set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, false);
} else {
auto single_op = std::getenv(kGraphOpRun);
if (single_op && std::string(single_op) == "1") {
auto single_op = common::GetEnv(kGraphOpRun);
if (single_op == "1") {
context_ptr->set_param<bool>(MS_CTX_ENABLE_TASK_SINK, false);
}
auto enable_mem_scheduler = common::GetEnv(kEnableMemScheduler);
if (enable_mem_scheduler == "1") {
context_ptr->set_param<bool>(MS_CTX_ENABLE_MEM_SCHEDULER, true);
context_ptr->set_param<bool>(MS_CTX_ENABLE_TASK_SINK, false);
}
}

View File

@ -81,6 +81,7 @@ MsContext::MsContext(const std::string &policy, const std::string &target) {
set_param<std::string>(MS_CTX_VARIABLE_MEMORY_MAX_SIZE, "0");
set_param<bool>(MS_CTX_ENABLE_LOOP_SINK, target == kAscendDevice || target == kDavinciDevice);
set_param<bool>(MS_CTX_ENABLE_PROFILING, false);
set_param<bool>(MS_CTX_ENABLE_MEM_SCHEDULER, false);
set_param<std::string>(MS_CTX_PROFILING_OPTIONS, "training_trace");
set_param<bool>(MS_CTX_CHECK_BPROP_FLAG, false);
set_param<float>(MS_CTX_MAX_DEVICE_MEMORY, kDefaultMaxDeviceMemory);

View File

@ -72,6 +72,7 @@ enum MsCtxParam : unsigned {
MS_CTX_ENABLE_GRAPH_KERNEL,
MS_CTX_ENABLE_HCCL,
MS_CTX_ENABLE_LOOP_SINK,
MS_CTX_ENABLE_MEM_SCHEDULER,
MS_CTX_ENABLE_PYNATIVE_HOOK,
MS_CTX_ENABLE_PYNATIVE_INFER,
MS_CTX_ENABLE_REDUCE_PRECISION,