forked from mindspore-Ecosystem/mindspore
DeviceAddress::MoveTo and Add SwapManager
This commit is contained in:
parent
0101b59ec2
commit
8569480db6
|
@ -3,7 +3,7 @@ file(GLOB_RECURSE DEVICE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "common/*
|
|||
"memory_manager.cc" "kernel_runtime_manager.cc" "convert_tensor_utils.cc" "memory_scheduler.cc"
|
||||
"memory_offload_strategy.cc" "launch_kernel.cc" "launch_mul.cc" "tensor_array.cc"
|
||||
"ms_device_shape_transfer.cc" "context_extends.cc" "stream_synchronizer.cc" "tensors_queue.cc" "auto_mem_offload.cc"
|
||||
"common_somas_allocator.cc" "device_address_utils.cc"
|
||||
"common_somas_allocator.cc" "device_address_utils.cc" "loadable_device_address.cc"
|
||||
)
|
||||
|
||||
if("${ENABLE_HIDDEN}" STREQUAL "OFF" AND NOT MSVC)
|
||||
|
|
|
@ -217,7 +217,7 @@ class DeviceAddress : public mindspore::DeviceSync {
|
|||
virtual bool Load(size_t) { MS_LOG(EXCEPTION) << "Not implemented."; }
|
||||
|
||||
// Move data to destination hardware and free resource on source hardware
|
||||
virtual bool MoveTo(StorageType, bool) { MS_LOG(EXCEPTION) << "Not implemented."; }
|
||||
virtual bool MoveTo(StorageType, bool, size_t) { MS_LOG(EXCEPTION) << "Not implemented."; }
|
||||
|
||||
virtual bool Wait() const { MS_LOG(EXCEPTION) << "Not implemented."; }
|
||||
|
||||
|
@ -230,6 +230,10 @@ class DeviceAddress : public mindspore::DeviceSync {
|
|||
virtual void SetStorageInfo(const StorageInfo &) {}
|
||||
|
||||
virtual void HandOver(DeviceAddress *other) {
|
||||
MS_EXCEPTION_IF_NULL(other);
|
||||
if (other == this) {
|
||||
return;
|
||||
}
|
||||
other->set_ptr(GetMutablePtr());
|
||||
other->set_from_mem_pool(from_mem_pool());
|
||||
set_ptr(nullptr);
|
||||
|
|
|
@ -0,0 +1,38 @@
|
|||
/**
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_GSM_HOST_MEM_POOL_H_
|
||||
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_GSM_HOST_MEM_POOL_H_
|
||||
|
||||
#include <memory>
|
||||
#include "include/backend/visible.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace device {
|
||||
class HostMemPool {
|
||||
public:
|
||||
explicit BACKEND_EXPORT HostMemPool(size_t max_size) : max_size_(max_size) {}
|
||||
~HostMemPool() = default;
|
||||
virtual void *Malloc(size_t size) { return nullptr; }
|
||||
virtual void Free(const void *ptr) {}
|
||||
|
||||
protected:
|
||||
size_t max_size_;
|
||||
};
|
||||
using HostMemPoolPtr = std::shared_ptr<HostMemPool>;
|
||||
} // namespace device
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_GSM_HOST_MEM_POOL_H_
|
|
@ -0,0 +1,91 @@
|
|||
/**
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "runtime/device/gsm/io_handle.h"
|
||||
#include <memory>
|
||||
#include "utils/system/env.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace device {
|
||||
constexpr size_t kFileHeadOffset = 0;
|
||||
|
||||
bool IOHandle::Read(const std::string &file_name, void *data, size_t byte_num) {
|
||||
if (aio_ != nullptr) {
|
||||
return aio_->Read(file_name, data, byte_num);
|
||||
}
|
||||
const auto &fs = system::Env::GetFileSystem();
|
||||
MS_EXCEPTION_IF_NULL(fs);
|
||||
auto file = fs->CreateWriteFile(GetSwapFileWholeName(file_name));
|
||||
MS_EXCEPTION_IF_NULL(file);
|
||||
return file->PRead(data, byte_num, kFileHeadOffset);
|
||||
}
|
||||
|
||||
bool IOHandle::Write(const std::string &file_name, const void *data, size_t byte_num) {
|
||||
if (aio_ != nullptr) {
|
||||
return aio_->Write(file_name, data, byte_num);
|
||||
}
|
||||
const auto &fs = system::Env::GetFileSystem();
|
||||
MS_EXCEPTION_IF_NULL(fs);
|
||||
auto file = fs->CreateWriteFile(GetSwapFileWholeName(file_name));
|
||||
MS_EXCEPTION_IF_NULL(file);
|
||||
return file->PWrite(data, byte_num, kFileHeadOffset);
|
||||
}
|
||||
|
||||
bool IOHandle::ReadAsync(const std::string &file_name, void *data, size_t byte_num, AsyncIOToken *token) {
|
||||
if (aio_ != nullptr) {
|
||||
return aio_->ReadAsync(file_name, data, byte_num, token);
|
||||
}
|
||||
const auto &fs = system::Env::GetFileSystem();
|
||||
MS_EXCEPTION_IF_NULL(fs);
|
||||
auto file = fs->CreateWriteFile(GetSwapFileWholeName(file_name));
|
||||
MS_EXCEPTION_IF_NULL(file);
|
||||
*token = kInvalidAsyncIOToken;
|
||||
return file->PRead(data, byte_num, kFileHeadOffset);
|
||||
}
|
||||
|
||||
bool IOHandle::WriteAsync(const std::string &file_name, const void *data, size_t byte_num, AsyncIOToken *token) {
|
||||
if (aio_ != nullptr) {
|
||||
return aio_->WriteAsync(file_name, data, byte_num, token);
|
||||
}
|
||||
const auto &fs = system::Env::GetFileSystem();
|
||||
MS_EXCEPTION_IF_NULL(fs);
|
||||
auto file = fs->CreateWriteFile(GetSwapFileWholeName(file_name));
|
||||
MS_EXCEPTION_IF_NULL(file);
|
||||
*token = kInvalidAsyncIOToken;
|
||||
return file->PWrite(data, byte_num, kFileHeadOffset);
|
||||
}
|
||||
|
||||
// Waits for the async operation identified by `token`.
// Without an AsyncIO backend there is nothing to wait for, so this reports success.
bool IOHandle::Wait(AsyncIOToken token) {
  if (aio_ == nullptr) {
    return true;
  }
  return aio_->Wait(token);
}
|
||||
|
||||
bool IOHandle::DeleteSwapFile(const std::string &file_name) const {
|
||||
const auto &fs = system::Env::GetFileSystem();
|
||||
MS_EXCEPTION_IF_NULL(fs);
|
||||
return fs->DeleteFile(GetSwapFileWholeName(file_name));
|
||||
}
|
||||
|
||||
bool IOHandle::CreateSwapFile(const std::string &file_name) const {
|
||||
const auto &fs = system::Env::GetFileSystem();
|
||||
MS_EXCEPTION_IF_NULL(fs);
|
||||
auto file = fs->CreateWriteFile(GetSwapFileWholeName(file_name));
|
||||
MS_EXCEPTION_IF_NULL(file);
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string IOHandle::GetSwapFileWholeName(const std::string &file_name) const {
|
||||
return swap_path_ + file_name + ".data";
|
||||
}
|
||||
} // namespace device
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,75 @@
|
|||
/**
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_UTILS_SYSTEM_IO_HANDLE_H_
|
||||
#define MINDSPORE_CORE_UTILS_SYSTEM_IO_HANDLE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <queue>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include "include/backend/visible.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace device {
|
||||
using AsyncIOToken = size_t;
|
||||
constexpr AsyncIOToken kInvalidAsyncIOToken = 0;
|
||||
|
||||
// Configuration passed to AsyncIO::Init.
// Both fields default to 0 so a default-constructed configuration never carries
// indeterminate values; aggregate initialization (`AsyncIOConf{bs, depth}`) still works.
struct AsyncIOConf {
  size_t block_size{0};
  size_t queue_depth{0};
};
|
||||
|
||||
class AsyncIO {
|
||||
public:
|
||||
AsyncIO() = default;
|
||||
virtual ~AsyncIO() = default;
|
||||
virtual bool Init(const AsyncIOConf &conf) = 0;
|
||||
virtual bool Read(const std::string &file_name, void *data, size_t byte_num) = 0;
|
||||
virtual bool Write(const std::string &file_name, const void *data, size_t byte_num) = 0;
|
||||
virtual bool ReadAsync(const std::string &file_name, void *data, size_t byte_num, AsyncIOToken *token) = 0;
|
||||
virtual bool WriteAsync(const std::string &file_name, const void *data, size_t byte_num, AsyncIOToken *token) = 0;
|
||||
virtual bool Wait(AsyncIOToken token) = 0;
|
||||
};
|
||||
|
||||
class BACKEND_EXPORT IOHandle {
|
||||
public:
|
||||
IOHandle() = default;
|
||||
~IOHandle() = default;
|
||||
bool DeleteSwapFile(const std::string &file_name) const;
|
||||
bool CreateSwapFile(const std::string &file_name) const;
|
||||
bool Read(const std::string &file_name, void *data, size_t byte_num);
|
||||
bool Write(const std::string &file_name, const void *data, size_t byte_num);
|
||||
bool ReadAsync(const std::string &file_name, void *data, size_t byte_num, AsyncIOToken *token);
|
||||
bool WriteAsync(const std::string &file_name, const void *data, size_t byte_num, AsyncIOToken *token);
|
||||
bool Wait(AsyncIOToken sync_token);
|
||||
|
||||
private:
|
||||
std::string GetSwapFileWholeName(const std::string &file_name) const;
|
||||
|
||||
private:
|
||||
bool use_aio_{false};
|
||||
size_t aio_block_size_{0};
|
||||
size_t aio_queue_depth_{0};
|
||||
std::string swap_path_;
|
||||
std::shared_ptr<AsyncIO> aio_{nullptr};
|
||||
};
|
||||
using IOHandlePtr = std::shared_ptr<IOHandle>;
|
||||
} // namespace device
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_UTILS_SYSTEM_IO_HANDLE_H_
|
|
@ -0,0 +1,190 @@
|
|||
/**
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "runtime/device/gsm/swap_manager.h"
|
||||
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
namespace mindspore {
|
||||
namespace device {
|
||||
// Stores the stream id and the device/host pools, and creates the IOHandle used for swap files.
// `device_memory_pool` is kept as a raw, non-owning pointer.
SwapManager::SwapManager(size_t stream_id, mindspore::device::DynamicMemPoolBestFit *device_memory_pool,
                         HostMemPoolPtr host_mem_pool)
    : stream_id_(stream_id),
      device_memory_pool_(device_memory_pool),
      host_memory_pool_(std::move(host_mem_pool)),
      io_handle_(std::make_shared<IOHandle>()) {}
|
||||
|
||||
template <class Input, class Output>
|
||||
bool SwapManager::TryAllocate(std::queue<const DeviceAddress *> queue, const Input &input,
|
||||
Output (SwapManager::*allocate_func)(const Input &),
|
||||
const std::function<bool(Output)> &success, Output *output) {
|
||||
MS_EXCEPTION_IF_NULL(allocate_func);
|
||||
MS_EXCEPTION_IF_NULL(output);
|
||||
(*output) = (this->*allocate_func)(input);
|
||||
if (success(*output)) {
|
||||
return true;
|
||||
}
|
||||
// Wait swapping tensors.
|
||||
while (!queue.empty()) {
|
||||
const auto &front = queue.front();
|
||||
MS_EXCEPTION_IF_NULL(front);
|
||||
if (front->Wait()) {
|
||||
(*output) = (this->*allocate_func)(input);
|
||||
if (success(*output)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
queue.pop();
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// One plain allocation attempt of `size` bytes from the device memory pool (no waiting or retry).
void *SwapManager::AllocDeviceMemorySimply(const size_t &size) {
  MS_EXCEPTION_IF_NULL(device_memory_pool_);
  void *device_ptr = device_memory_pool_->AllocTensorMem(size);
  return device_ptr;
}
|
||||
|
||||
void *SwapManager::AllocDeviceMemory(size_t size) {
|
||||
void *ret = nullptr;
|
||||
void *(SwapManager::*allocate_func)(const size_t &) = &SwapManager::AllocDeviceMemorySimply;
|
||||
std::function<bool(void *)> success = [](void *ptr) { return ptr != nullptr; };
|
||||
std::lock_guard<std::mutex> lock(swapping_tensors_device_mutex_);
|
||||
if (!TryAllocate(swapping_tensors_device_, size, allocate_func, success, &ret)) {
|
||||
MS_LOG(WARNING) << "Allocate device memory failed, size: " << size;
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
// One plain attempt to allocate the continuous blocks described by `size_list`
// from the device memory pool (no waiting or retry).
std::vector<void *> SwapManager::AllocDeviceContinuousMemSimply(const std::vector<size_t> &size_list) {
  MS_EXCEPTION_IF_NULL(device_memory_pool_);
  std::vector<void *> device_ptrs = device_memory_pool_->AllocContinuousTensorMem(size_list);
  return device_ptrs;
}
|
||||
|
||||
std::vector<void *> SwapManager::AllocDeviceContinuousMem(const std::vector<size_t> &size_list) {
|
||||
std::vector<void *> ret;
|
||||
std::vector<void *> (SwapManager::*allocate_func)(const std::vector<size_t> &) =
|
||||
&SwapManager::AllocDeviceContinuousMemSimply;
|
||||
std::function<bool(std::vector<void *>)> success = [](const std::vector<void *> &ptrs) { return !ptrs.empty(); };
|
||||
std::lock_guard<std::mutex> lock(swapping_tensors_device_mutex_);
|
||||
if (!TryAllocate(swapping_tensors_device_, size_list, allocate_func, success, &ret)) {
|
||||
MS_LOG(WARNING) << "Allocate continuous device mem failed, size list: " << size_list;
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
// Hands `ptr` back to the device memory pool.
void SwapManager::FreeDeviceMemory(void *ptr) {
  auto *pool = device_memory_pool_;
  MS_EXCEPTION_IF_NULL(pool);
  pool->FreeTensorMem(ptr);
}
|
||||
|
||||
// One plain allocation attempt of `size` bytes from the host memory pool (no waiting or retry).
void *SwapManager::AllocHostMemorySimply(const size_t &size) {
  MS_EXCEPTION_IF_NULL(host_memory_pool_);
  void *host_ptr = host_memory_pool_->Malloc(size);
  return host_ptr;
}
|
||||
|
||||
void *SwapManager::AllocHostMemory(size_t size) {
|
||||
void *ret = nullptr;
|
||||
void *(SwapManager::*allocate_func)(const size_t &) = &SwapManager::AllocHostMemorySimply;
|
||||
std::function<bool(void *)> success = [](void *ptr) { return ptr != nullptr; };
|
||||
std::lock_guard<std::mutex> lock(swapping_tensors_host_mutex_);
|
||||
if (!TryAllocate(swapping_tensors_host_, size, allocate_func, success, &ret)) {
|
||||
MS_LOG(WARNING) << "Allocate host memory failed, size: " << size;
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
// Hands `ptr` back to the host memory pool.
void SwapManager::FreeHostMemory(void *ptr) {
  const auto &pool = host_memory_pool_;
  MS_EXCEPTION_IF_NULL(pool);
  pool->Free(ptr);
}
|
||||
|
||||
// Creates the swap file `file_name` through the IO handle and returns its result.
bool SwapManager::CreateFile(const std::string &file_name) {
  MS_EXCEPTION_IF_NULL(io_handle_);
  const bool created = io_handle_->CreateSwapFile(file_name);
  return created;
}
|
||||
|
||||
bool SwapManager::DeleteFile(const std::string &file_name) {
|
||||
MS_EXCEPTION_IF_NULL(io_handle_);
|
||||
const auto &iter = file_size_.find(file_name);
|
||||
if (iter == file_size_.end()) {
|
||||
MS_LOG(WARNING) << "Can not file size for file[" << file_name << "]";
|
||||
} else {
|
||||
current_used_file_size_ -= iter->second;
|
||||
iter->second = 0;
|
||||
}
|
||||
return io_handle_->DeleteSwapFile(file_name);
|
||||
}
|
||||
|
||||
bool SwapManager::FileToHostMemory(void *host_memory, const std::string &file_name, size_t byte_num, bool async,
|
||||
AsyncIOToken *sync_key) {
|
||||
MS_EXCEPTION_IF_NULL(io_handle_);
|
||||
if (async) {
|
||||
return io_handle_->ReadAsync(file_name, host_memory, byte_num, sync_key);
|
||||
} else {
|
||||
return io_handle_->Read(file_name, host_memory, byte_num);
|
||||
}
|
||||
}
|
||||
|
||||
// True when `size` more bytes still fit under the swap-file budget (max_file_size_).
bool SwapManager::EnoughFileSpace(const size_t &size) {
  const size_t used_after_write = current_used_file_size_ + size;
  return used_after_write <= max_file_size_;
}
|
||||
|
||||
bool SwapManager::HostMemoryToFile(const std::string &file_name, const void *data, size_t byte_num, bool async,
|
||||
AsyncIOToken *sync_key) {
|
||||
MS_EXCEPTION_IF_NULL(io_handle_);
|
||||
bool (SwapManager::*allocate_func)(const size_t &size) = &SwapManager::EnoughFileSpace;
|
||||
std::function<bool(bool)> success = [](bool ret) { return ret; };
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(swapping_tensors_file_mutex_);
|
||||
bool enough = false;
|
||||
if (!TryAllocate(swapping_tensors_file_, byte_num, allocate_func, success, &enough)) {
|
||||
MS_LOG(WARNING) << "There is no enough disk space for file writing, size: " << byte_num;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
current_used_file_size_ += byte_num;
|
||||
file_size_[file_name] = byte_num;
|
||||
if (async) {
|
||||
return io_handle_->WriteAsync(file_name, data, byte_num, sync_key);
|
||||
} else {
|
||||
return io_handle_->Write(file_name, data, byte_num);
|
||||
}
|
||||
}
|
||||
|
||||
// Waits for the async file operation identified by `sync_token`; forwards to IOHandle::Wait.
bool SwapManager::WaitAsyncIO(mindspore::device::AsyncIOToken sync_token) {
  MS_EXCEPTION_IF_NULL(io_handle_);
  const bool finished = io_handle_->Wait(sync_token);
  return finished;
}
|
||||
|
||||
void SwapManager::AddSwappableTensor(const DeviceAddressPtr &device_address) {
|
||||
swappable_tensors_.push(device_address);
|
||||
}
|
||||
|
||||
// Records a tensor whose swap is in flight so that later allocations can wait on it.
// The target queue is chosen from the tensor's current status:
//   kInFileToHost   -> swapping_tensors_file_
//   kInDeviceToHost -> swapping_tensors_device_
//   anything else   -> swapping_tensors_host_
// A null address is ignored.
void SwapManager::AddSwappingTensor(const mindspore::device::DeviceAddress *device_address) {
  if (device_address == nullptr) {
    return;
  }
  switch (device_address->status()) {
    case DeviceAddressStatus::kInFileToHost: {
      std::lock_guard<std::mutex> lock(swapping_tensors_file_mutex_);
      (void)swapping_tensors_file_.push(device_address);
      break;
    }
    case DeviceAddressStatus::kInDeviceToHost: {
      std::lock_guard<std::mutex> lock(swapping_tensors_device_mutex_);
      (void)swapping_tensors_device_.push(device_address);
      break;
    }
    default: {
      std::lock_guard<std::mutex> lock(swapping_tensors_host_mutex_);
      (void)swapping_tensors_host_.push(device_address);
      break;
    }
  }
}
|
||||
} // namespace device
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,91 @@
|
|||
/**
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_GSM_SWAP_MANAGER_H_
|
||||
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_GSM_SWAP_MANAGER_H_
|
||||
|
||||
#include <memory>
|
||||
#include <queue>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "common/mem_reuse/mem_dynamic_allocator.h"
|
||||
#include "runtime/device/device_address.h"
|
||||
#include "runtime/device/gsm/io_handle.h"
|
||||
#include "runtime/device/gsm/host_mem_pool.h"
|
||||
#include "include/backend/visible.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace device {
|
||||
class BACKEND_EXPORT SwapManager {
|
||||
public:
|
||||
SwapManager(size_t stream_id, DynamicMemPoolBestFit *device_memory_pool, HostMemPoolPtr host_memory_pool);
|
||||
~SwapManager() = default;
|
||||
// Device memory
|
||||
void *AllocDeviceMemory(size_t size);
|
||||
std::vector<void *> AllocDeviceContinuousMem(const std::vector<size_t> &size_list);
|
||||
void FreeDeviceMemory(void *ptr);
|
||||
|
||||
// Host memory
|
||||
void *AllocHostMemory(size_t size);
|
||||
void FreeHostMemory(void *ptr);
|
||||
|
||||
// File
|
||||
bool CreateFile(const std::string &file_name);
|
||||
bool DeleteFile(const std::string &file_name);
|
||||
bool FileToHostMemory(void *host_memory, const std::string &file_name, size_t byte_num, bool async,
|
||||
AsyncIOToken *sync_token);
|
||||
bool HostMemoryToFile(const std::string &file_name, const void *data, size_t byte_num, bool async,
|
||||
AsyncIOToken *sync_token);
|
||||
bool WaitAsyncIO(AsyncIOToken sync_token);
|
||||
|
||||
// Swapping and swappable tensors
|
||||
void AddSwappableTensor(const DeviceAddressPtr &device_address);
|
||||
void AddSwappingTensor(const DeviceAddress *device_address);
|
||||
|
||||
private:
|
||||
void *AllocDeviceMemorySimply(const size_t &size);
|
||||
std::vector<void *> AllocDeviceContinuousMemSimply(const std::vector<size_t> &size_list);
|
||||
void *AllocHostMemorySimply(const size_t &size);
|
||||
bool EnoughFileSpace(const size_t &size);
|
||||
|
||||
template <class Input, class Output>
|
||||
bool TryAllocate(std::queue<const DeviceAddress *> queue, const Input &input,
|
||||
Output (SwapManager::*allocate_func)(const Input &), const std::function<bool(Output)> &success,
|
||||
Output *output);
|
||||
|
||||
private:
|
||||
size_t stream_id_;
|
||||
DynamicMemPoolBestFit *device_memory_pool_;
|
||||
HostMemPoolPtr host_memory_pool_;
|
||||
size_t max_file_size_{0};
|
||||
size_t current_used_file_size_{0};
|
||||
HashMap<std::string, size_t> file_size_;
|
||||
struct compare {
|
||||
bool operator()(const DeviceAddressPtr &l, const DeviceAddressPtr &r) const { return l->GetSize() < r->GetSize(); }
|
||||
};
|
||||
std::priority_queue<DeviceAddressPtr, std::vector<DeviceAddressPtr>, compare> swappable_tensors_;
|
||||
std::mutex swapping_tensors_device_mutex_;
|
||||
std::queue<const DeviceAddress *> swapping_tensors_device_;
|
||||
std::mutex swapping_tensors_host_mutex_;
|
||||
std::queue<const DeviceAddress *> swapping_tensors_host_;
|
||||
std::mutex swapping_tensors_file_mutex_;
|
||||
std::queue<const DeviceAddress *> swapping_tensors_file_;
|
||||
IOHandlePtr io_handle_;
|
||||
};
|
||||
} // namespace device
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_GSM_SWAP_MANAGER_H_
|
|
@ -20,9 +20,10 @@
|
|||
#include <queue>
|
||||
#include "runtime/device/gsm/swap_strategy.h"
|
||||
#include "runtime/device/gsm/mem_usage_analyzer.h"
|
||||
#include "include/backend/visible.h"
|
||||
namespace mindspore {
|
||||
namespace device {
|
||||
class SwapStrategyBuilder {
|
||||
class BACKEND_EXPORT SwapStrategyBuilder {
|
||||
public:
|
||||
SwapStrategyBuilder() = default;
|
||||
~SwapStrategyBuilder() = default;
|
||||
|
|
|
@ -0,0 +1,324 @@
|
|||
/**
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "runtime/device/loadable_device_address.h"
|
||||
#include "include/common/debug/common.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace device {
|
||||
namespace {
|
||||
constexpr size_t kFileAlignSize = 512;
|
||||
} // namespace
|
||||
|
||||
bool LoadableDeviceAddress::Offload(size_t stream_id) {
|
||||
std::lock_guard<std::recursive_mutex> lock(ptr_mutex_);
|
||||
if (mem_offloaded_) {
|
||||
MS_LOG(WARNING) << "Trying to offload an offloaded AscendDeviceAddress.";
|
||||
return true;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(ptr_);
|
||||
auto device_context = GetDeviceContext();
|
||||
MS_EXCEPTION_IF_NULL(device_context);
|
||||
offload_ptr_ = device_context->device_res_manager_->AllocateOffloadMemory(size_);
|
||||
if (offload_ptr_ == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Alloc host memory for offloading failed, size: " << size_ << ".";
|
||||
}
|
||||
if (!AsyncDeviceToHost({}, size_, kTypeUnknown, offload_ptr_, stream_id)) {
|
||||
return false;
|
||||
}
|
||||
device_context->device_res_manager_->FreeMemory(ptr_);
|
||||
ptr_ = nullptr;
|
||||
mem_offloaded_ = true;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool LoadableDeviceAddress::Load(size_t stream_id) {
|
||||
std::lock_guard<std::recursive_mutex> lock(ptr_mutex_);
|
||||
if (!mem_offloaded_) {
|
||||
MS_LOG(DEBUG) << "Trying to load a loaded AscendDeviceAddress.";
|
||||
return true;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(offload_ptr_);
|
||||
auto device_context = GetDeviceContext();
|
||||
MS_EXCEPTION_IF_NULL(device_context);
|
||||
if (ptr_ == nullptr && !device_context->device_res_manager_->AllocateMemory(this)) {
|
||||
MS_LOG(EXCEPTION) << "Alloc memory for loading failed, size: " << size_ << ".";
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(ptr_);
|
||||
if (!AsyncHostToDevice({}, size_, kTypeUnknown, offload_ptr_, stream_id)) {
|
||||
return false;
|
||||
}
|
||||
device_context->device_res_manager_->FreeOffloadMemory(offload_ptr_);
|
||||
offload_ptr_ = nullptr;
|
||||
mem_offloaded_ = false;
|
||||
return true;
|
||||
}
|
||||
|
||||
// Moves the data of this address to the storage tier `dst`.
// Any swap still in flight is waited for first. Returns false (after logging a warning)
// when the wait or the move fails; a `dst` other than device/host/file is a no-op that
// returns true.
bool LoadableDeviceAddress::MoveTo(mindspore::device::StorageType dst, bool async, size_t stream_id) {
  if (!Wait()) {
    MS_LOG(WARNING) << "Wait swapping DeviceAddress failed. Status: " << status_;
    return false;
  }
  switch (dst) {
    case StorageType::kDevice:
      if (!MoveToDevice(async, stream_id)) {
        MS_LOG(WARNING) << "Move data to device failed.";
        return false;
      }
      break;
    case StorageType::kHost:
      if (!MoveToHost(async, stream_id)) {
        MS_LOG(WARNING) << "Move data to host failed.";
        return false;
      }
      break;
    case StorageType::kFile:
      if (!MoveToFile(async, stream_id)) {
        MS_LOG(WARNING) << "Move data to file failed.";
        return false;
      }
      break;
    default:
      break;
  }
  return true;
}
|
||||
|
||||
bool LoadableDeviceAddress::MoveToHost(bool async, size_t stream_id) const {
|
||||
const auto device_context = GetDeviceContext();
|
||||
MS_EXCEPTION_IF_NULL(device_context);
|
||||
const auto swap_manager = device_context->device_res_manager_->swap_manager();
|
||||
MS_EXCEPTION_IF_NULL(swap_manager);
|
||||
storage_info_.host_ptr_ = swap_manager->AllocHostMemory(GetFileAlignSize());
|
||||
MS_EXCEPTION_IF_NULL(storage_info_.host_ptr_);
|
||||
if (status_ == DeviceAddressStatus::kInFile) {
|
||||
if (!CopyFileToHost(storage_info_.host_ptr_, storage_info_.file_name_, GetFileAlignSize(), async)) {
|
||||
MS_LOG(WARNING) << "Copy data from file to host failed.";
|
||||
return false;
|
||||
}
|
||||
if (async) {
|
||||
swap_manager->AddSwappingTensor(this);
|
||||
status_ = DeviceAddressStatus::kInFileToHost;
|
||||
} else {
|
||||
swap_manager->DeleteFile(storage_info_.file_name_);
|
||||
storage_info_.file_name_ = "";
|
||||
status_ = DeviceAddressStatus::kInHost;
|
||||
}
|
||||
} else {
|
||||
if (!CopyDeviceToHost(storage_info_.host_ptr_, ptr_, size_, async, stream_id)) {
|
||||
MS_LOG(WARNING) << "Copy data from device to host failed.";
|
||||
return false;
|
||||
}
|
||||
if (async) {
|
||||
swap_manager->AddSwappingTensor(this);
|
||||
status_ = DeviceAddressStatus::kInDeviceToHost;
|
||||
} else {
|
||||
swap_manager->FreeDeviceMemory(ptr_);
|
||||
ptr_ = nullptr;
|
||||
status_ = DeviceAddressStatus::kInHost;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool LoadableDeviceAddress::MoveToDevice(bool async, size_t stream_id) const {
|
||||
if (status_ == DeviceAddressStatus::kInDevice) {
|
||||
return true;
|
||||
}
|
||||
const auto device_context = GetDeviceContext();
|
||||
MS_EXCEPTION_IF_NULL(device_context);
|
||||
const auto swap_manager = device_context->device_res_manager_->swap_manager();
|
||||
MS_EXCEPTION_IF_NULL(swap_manager);
|
||||
if (status_ == DeviceAddressStatus::kInFile && !MoveToHost(false, stream_id)) {
|
||||
return false;
|
||||
}
|
||||
ptr_ = swap_manager->AllocDeviceMemory(size_);
|
||||
MS_EXCEPTION_IF_NULL(ptr_);
|
||||
if (!CopyHostToDevice(ptr_, storage_info_.host_ptr_, size_, async, stream_id)) {
|
||||
MS_LOG(WARNING) << "Copy data from host to device failed.";
|
||||
return false;
|
||||
}
|
||||
if (async) {
|
||||
swap_manager->AddSwappingTensor(this);
|
||||
status_ = DeviceAddressStatus::kInHostToDevice;
|
||||
} else {
|
||||
swap_manager->FreeHostMemory(storage_info_.host_ptr_);
|
||||
storage_info_.host_ptr_ = nullptr;
|
||||
status_ = DeviceAddressStatus::kInDevice;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool LoadableDeviceAddress::MoveToFile(bool async, size_t stream_id) const {
|
||||
if (status_ == DeviceAddressStatus::kInFile) {
|
||||
return true;
|
||||
}
|
||||
const auto device_context = GetDeviceContext();
|
||||
MS_EXCEPTION_IF_NULL(device_context);
|
||||
const auto swap_manager = device_context->device_res_manager_->swap_manager();
|
||||
MS_EXCEPTION_IF_NULL(swap_manager);
|
||||
if (status_ == DeviceAddressStatus::kInDevice && !MoveToHost(false, stream_id)) {
|
||||
return false;
|
||||
}
|
||||
storage_info_.file_name_ = GetSwapFileName();
|
||||
if (!swap_manager->CreateFile(storage_info_.file_name_)) {
|
||||
MS_LOG(WARNING) << "Create file for swapping failed.";
|
||||
return false;
|
||||
}
|
||||
if (!CopyHostToFile(storage_info_.file_name_, storage_info_.host_ptr_, GetFileAlignSize(), async)) {
|
||||
MS_LOG(WARNING) << "Copy data from host to file failed.";
|
||||
return false;
|
||||
}
|
||||
if (async) {
|
||||
swap_manager->AddSwappingTensor(this);
|
||||
status_ = DeviceAddressStatus::kInHostToFile;
|
||||
} else {
|
||||
swap_manager->FreeHostMemory(storage_info_.host_ptr_);
|
||||
storage_info_.host_ptr_ = nullptr;
|
||||
status_ = DeviceAddressStatus::kInFile;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool LoadableDeviceAddress::CopyHostToFile(const std::string &dst, const void *src, size_t size, bool async) const {
|
||||
MS_EXCEPTION_IF_NULL(src);
|
||||
const auto device_context = GetDeviceContext();
|
||||
MS_EXCEPTION_IF_NULL(device_context);
|
||||
const auto swap_manager = device_context->device_res_manager_->swap_manager();
|
||||
MS_EXCEPTION_IF_NULL(swap_manager);
|
||||
AsyncIOToken token;
|
||||
bool ret = swap_manager->HostMemoryToFile(dst, src, size, async, &token);
|
||||
if (!ret) {
|
||||
MS_LOG(WARNING) << "Write data from ddr to file[" << dst << "] failed.";
|
||||
return ret;
|
||||
}
|
||||
if (async) {
|
||||
MS_EXCEPTION_IF_NULL(swap_event_);
|
||||
swap_event_->aio_token_ = token;
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
bool LoadableDeviceAddress::CopyFileToHost(void *dst, const std::string &src, size_t size, bool async) const {
|
||||
MS_EXCEPTION_IF_NULL(dst);
|
||||
const auto device_context = GetDeviceContext();
|
||||
MS_EXCEPTION_IF_NULL(device_context);
|
||||
const auto swap_manager = device_context->device_res_manager_->swap_manager();
|
||||
MS_EXCEPTION_IF_NULL(swap_manager);
|
||||
AsyncIOToken token;
|
||||
bool ret = swap_manager->FileToHostMemory(dst, src, size, async, &token);
|
||||
if (!ret) {
|
||||
MS_LOG(WARNING) << "Read data from file[" << src << "] to ddr failed.";
|
||||
return ret;
|
||||
}
|
||||
if (async) {
|
||||
MS_EXCEPTION_IF_NULL(swap_event_);
|
||||
swap_event_->aio_token_ = token;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
size_t LoadableDeviceAddress::GetFileAlignSize() const {
|
||||
return (size_ + kFileAlignSize - 1) / kFileAlignSize * kFileAlignSize;
|
||||
}
|
||||
|
||||
std::string LoadableDeviceAddress::GetSwapFileName() const {
|
||||
static size_t swap_file_index = 0;
|
||||
return std::to_string(device_id()) + "_" + std::to_string(swap_file_index++) + "_" +
|
||||
std::to_string(Common::GetTimeStamp());
|
||||
}
|
||||
|
||||
// Replace the stored storage info and derive the status from it while holding ptr_mutex_:
// a non-null host pointer means kInHost, otherwise a non-empty file name means kInFile, otherwise kInDevice.
void LoadableDeviceAddress::SetStorageInfo(const mindspore::device::StorageInfo &storage_info) {
  std::lock_guard<std::recursive_mutex> guard(ptr_mutex_);
  storage_info_ = storage_info;
  if (storage_info_.host_ptr_ != nullptr) {
    status_ = DeviceAddressStatus::kInHost;
    return;
  }
  status_ = storage_info_.file_name_.empty() ? DeviceAddressStatus::kInDevice : DeviceAddressStatus::kInFile;
}
|
||||
|
||||
// Hand the resources of this address over to `other`.
// The device pointer is transferred by DeviceAddress::HandOver. If `other` is a LoadableDeviceAddress, the
// storage info (host pointer / swap file name) and status are transferred as well and this address is reset
// to an empty kInDevice state. If `other` is some other DeviceAddress, only the base hand-over happens.
void LoadableDeviceAddress::HandOver(mindspore::device::DeviceAddress *other) {
  // The base implementation raises on nullptr and does nothing for a self hand-over.
  DeviceAddress::HandOver(other);
  if (other == this) {
    // Handing over to itself must stay a no-op; otherwise the reset below would wipe the storage info
    // from the only object that holds it.
    return;
  }
  // dynamic_cast checks the real type of `other`; reinterpret_cast would never yield nullptr and would
  // write LoadableDeviceAddress members into an object that may not have them.
  auto loadable_device_address = dynamic_cast<LoadableDeviceAddress *>(other);
  if (loadable_device_address == nullptr) {
    return;
  }
  loadable_device_address->SetStorageInfo(storage_info_);
  // SetStorageInfo derives a status from the storage info; overwrite it with the exact current status.
  loadable_device_address->set_status(status_);
  storage_info_.host_ptr_ = nullptr;
  storage_info_.file_name_ = "";
  status_ = DeviceAddressStatus::kInDevice;
}
|
||||
|
||||
bool LoadableDeviceAddress::Wait() const {
|
||||
if (swap_event_ == nullptr || !swap_event_->NeedWait()) {
|
||||
return true;
|
||||
}
|
||||
const auto device_context = GetDeviceContext();
|
||||
MS_EXCEPTION_IF_NULL(device_context);
|
||||
const auto swap_manager = device_context->device_res_manager_->swap_manager();
|
||||
MS_EXCEPTION_IF_NULL(swap_manager);
|
||||
if (swap_event_->device_event_ != nullptr && swap_event_->device_event_->NeedWait()) {
|
||||
swap_event_->device_event_->WaitEvent();
|
||||
} else if (swap_event_->aio_token_ != kInvalidAsyncIOToken) {
|
||||
if (!swap_manager->WaitAsyncIO(swap_event_->aio_token_)) {
|
||||
MS_LOG(WARNING) << "Wait aio failed.";
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
MS_LOG(WARNING) << "Device address is in moving, but no valid swap event can be found.";
|
||||
}
|
||||
if (status_ == DeviceAddressStatus::kInFileToHost) {
|
||||
swap_manager->DeleteFile(storage_info_.file_name_);
|
||||
status_ = DeviceAddressStatus::kInHost;
|
||||
} else if (status_ == DeviceAddressStatus::kInDeviceToHost) {
|
||||
swap_manager->FreeDeviceMemory(ptr_);
|
||||
status_ = DeviceAddressStatus::kInHost;
|
||||
} else {
|
||||
swap_manager->FreeHostMemory(storage_info_.host_ptr_);
|
||||
if (status_ == DeviceAddressStatus::kInHostToDevice) {
|
||||
status_ = DeviceAddressStatus::kInHost;
|
||||
} else {
|
||||
status_ = DeviceAddressStatus::kInFile;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void LoadableDeviceAddress::SetOffloadPtr(void *offload_ptr) {
|
||||
std::lock_guard<std::recursive_mutex> lock(ptr_mutex_);
|
||||
offload_ptr_ = offload_ptr;
|
||||
mem_offloaded_ = (offload_ptr != nullptr);
|
||||
}
|
||||
|
||||
void *LoadableDeviceAddress::GetOffloadPtr() const {
|
||||
std::lock_guard<std::recursive_mutex> lock(ptr_mutex_);
|
||||
return offload_ptr_;
|
||||
}
|
||||
|
||||
// Return whether DeviceAddress has a valid ptr.
|
||||
bool LoadableDeviceAddress::IsPtrValid() const {
|
||||
std::lock_guard<std::recursive_mutex> lock(ptr_mutex_);
|
||||
return ptr_ != nullptr || offload_ptr_ != nullptr;
|
||||
}
|
||||
|
||||
// Load first if data is offloaded and return the device ptr.
|
||||
void *LoadableDeviceAddress::GetValidPtr(size_t stream_id) {
|
||||
std::lock_guard<std::recursive_mutex> lock(ptr_mutex_);
|
||||
if (mem_offloaded() && !Load(stream_id)) {
|
||||
MS_LOG(EXCEPTION) << "Load offloaded memory failed";
|
||||
}
|
||||
return ptr_;
|
||||
}
|
||||
} // namespace device
|
||||
} // namespace mindspore
|
|
@ -17,6 +17,7 @@
|
|||
#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_LOADABLE_DEVICE_ADDRESS_H_
|
||||
#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_LOADABLE_DEVICE_ADDRESS_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include "runtime/device/device_address.h"
|
||||
#include "runtime/hardware/device_context.h"
|
||||
|
@ -24,108 +25,98 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace device {
|
||||
// LoadableDeviceAddress provide the ability to offload data on device to host and load it back later.
|
||||
class LoadableDeviceAddress : public DeviceAddress {
|
||||
struct SwapEvent {
|
||||
bool NeedWait() {
|
||||
return aio_token_ != kInvalidAsyncIOToken || (device_event_ != nullptr && device_event_->NeedWait());
|
||||
}
|
||||
AsyncIOToken aio_token_{kInvalidAsyncIOToken};
|
||||
std::shared_ptr<DeviceEvent> device_event_;
|
||||
};
|
||||
using SwapEventPtr = std::shared_ptr<SwapEvent>;
|
||||
|
||||
// LoadableDeviceAddress provide the ability to offload data on device to ddr or disk and load it back later.
|
||||
class BACKEND_EXPORT LoadableDeviceAddress : public DeviceAddress {
|
||||
public:
|
||||
LoadableDeviceAddress(void *ptr, size_t size) : DeviceAddress(ptr, size) {}
|
||||
LoadableDeviceAddress(void *ptr, size_t size) : DeviceAddress(ptr, size) {
|
||||
swap_event_ = std::make_shared<SwapEvent>();
|
||||
}
|
||||
LoadableDeviceAddress(void *ptr, size_t size, const string &format, TypeId type_id)
|
||||
: DeviceAddress(ptr, size, format, type_id) {}
|
||||
: DeviceAddress(ptr, size, format, type_id) {
|
||||
swap_event_ = std::make_shared<SwapEvent>();
|
||||
}
|
||||
LoadableDeviceAddress(void *ptr, size_t size, const std::string &format, TypeId type_id,
|
||||
const KernelWithIndex &node_index)
|
||||
: DeviceAddress(ptr, size, format, type_id, node_index) {}
|
||||
: DeviceAddress(ptr, size, format, type_id, node_index) {
|
||||
swap_event_ = std::make_shared<SwapEvent>();
|
||||
}
|
||||
LoadableDeviceAddress(void *ptr, size_t size, const std::string &format, TypeId type_id,
|
||||
const std::string &device_name, uint32_t device_id)
|
||||
: DeviceAddress(ptr, size, format, type_id, device_name, device_id) {}
|
||||
: DeviceAddress(ptr, size, format, type_id, device_name, device_id) {
|
||||
swap_event_ = std::make_shared<SwapEvent>();
|
||||
}
|
||||
LoadableDeviceAddress(void *ptr, size_t size, const std::string &device_name, uint32_t device_id)
|
||||
: DeviceAddress(ptr, size, device_name, device_id) {}
|
||||
: DeviceAddress(ptr, size, device_name, device_id) {
|
||||
swap_event_ = std::make_shared<SwapEvent>();
|
||||
}
|
||||
LoadableDeviceAddress(void *ptr, size_t size, const std::string &format, TypeId type_id,
|
||||
const KernelWithIndex &node_index, const std::string &device_name, uint32_t device_id)
|
||||
: DeviceAddress(ptr, size, format, type_id, node_index, device_name, device_id) {}
|
||||
: DeviceAddress(ptr, size, format, type_id, node_index, device_name, device_id) {
|
||||
swap_event_ = std::make_shared<SwapEvent>();
|
||||
}
|
||||
|
||||
bool mem_offloaded() const final { return mem_offloaded_; }
|
||||
|
||||
// Offload data from device to host and free device memory
|
||||
bool Offload(size_t stream_id) final;
|
||||
|
||||
// Load data from host to device and free host memory
|
||||
bool Load(size_t stream_id) final;
|
||||
|
||||
// Move data to destination hardware and free resource on source hardware
|
||||
bool MoveTo(StorageType dest, bool async, size_t stream_id) override;
|
||||
|
||||
bool Wait() const override;
|
||||
|
||||
void SetStorageInfo(const StorageInfo &storage_info) final;
|
||||
|
||||
// Set host ptr data offloaded to
|
||||
void SetOffloadPtr(void *offload_ptr) final;
|
||||
// Get offloaded host ptr
|
||||
void *GetOffloadPtr() const final;
|
||||
|
||||
// Return whether DeviceAddress has a valid ptr.
|
||||
bool IsPtrValid() const final;
|
||||
|
||||
// Load first if data is offloaded and return the device ptr.
|
||||
void *GetValidPtr(size_t stream_id) final;
|
||||
|
||||
void HandOver(DeviceAddress *other) override;
|
||||
|
||||
protected:
|
||||
DeviceContext *GetDeviceContext() const {
|
||||
return DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_name_, device_id_});
|
||||
}
|
||||
|
||||
bool MoveToDevice(bool async, size_t stream_id) const;
|
||||
bool MoveToHost(bool async, size_t stream_id) const;
|
||||
bool MoveToFile(bool async, size_t stream_id) const;
|
||||
|
||||
virtual bool CopyDeviceToHost(void *dst, const void *src, size_t size, bool async, size_t stream_id) const {
|
||||
return false;
|
||||
}
|
||||
virtual bool CopyHostToDevice(void *dst, const void *src, size_t size, bool async, size_t stream_id) const {
|
||||
return false;
|
||||
}
|
||||
virtual bool CopyHostToFile(const std::string &dst, const void *src, size_t size, bool async) const;
|
||||
virtual bool CopyFileToHost(void *dst, const std::string &src, size_t size, bool async) const;
|
||||
|
||||
std::string GetSwapFileName() const;
|
||||
size_t GetFileAlignSize() const;
|
||||
|
||||
bool mem_offloaded_{false};
|
||||
void *offload_ptr_{nullptr};
|
||||
|
||||
public:
|
||||
bool mem_offloaded() const final { return mem_offloaded_; }
|
||||
|
||||
// Offload data from device to host and free device memory
|
||||
bool Offload(size_t stream_id) final {
|
||||
std::lock_guard<std::recursive_mutex> lock(ptr_mutex_);
|
||||
if (mem_offloaded_) {
|
||||
MS_LOG(WARNING) << "Trying to offload an offloaded AscendDeviceAddress.";
|
||||
return true;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(ptr_);
|
||||
auto device_context = GetDeviceContext();
|
||||
MS_EXCEPTION_IF_NULL(device_context);
|
||||
offload_ptr_ = device_context->device_res_manager_->AllocateOffloadMemory(size_);
|
||||
if (offload_ptr_ == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Alloc host memory for offloading failed, size: " << size_ << ".";
|
||||
}
|
||||
if (!AsyncDeviceToHost({}, size_, kTypeUnknown, offload_ptr_, stream_id)) {
|
||||
return false;
|
||||
}
|
||||
device_context->device_res_manager_->FreeMemory(ptr_);
|
||||
ptr_ = nullptr;
|
||||
mem_offloaded_ = true;
|
||||
return true;
|
||||
}
|
||||
|
||||
// Load data from host to device and free host memory
|
||||
bool Load(size_t stream_id) final {
|
||||
std::lock_guard<std::recursive_mutex> lock(ptr_mutex_);
|
||||
if (!mem_offloaded_) {
|
||||
MS_LOG(DEBUG) << "Trying to load a loaded AscendDeviceAddress.";
|
||||
return true;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(offload_ptr_);
|
||||
auto device_context = GetDeviceContext();
|
||||
MS_EXCEPTION_IF_NULL(device_context);
|
||||
if (ptr_ == nullptr && !device_context->device_res_manager_->AllocateMemory(this)) {
|
||||
MS_LOG(EXCEPTION) << "Alloc memory for loading failed, size: " << size_ << ".";
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(ptr_);
|
||||
if (!AsyncHostToDevice({}, size_, kTypeUnknown, offload_ptr_, stream_id)) {
|
||||
return false;
|
||||
}
|
||||
device_context->device_res_manager_->FreeOffloadMemory(offload_ptr_);
|
||||
offload_ptr_ = nullptr;
|
||||
mem_offloaded_ = false;
|
||||
return true;
|
||||
}
|
||||
|
||||
// Set host ptr data offloaded to
|
||||
void SetOffloadPtr(void *offload_ptr) final {
|
||||
std::lock_guard<std::recursive_mutex> lock(ptr_mutex_);
|
||||
offload_ptr_ = offload_ptr;
|
||||
mem_offloaded_ = (offload_ptr != nullptr);
|
||||
}
|
||||
|
||||
// Get offloaded host ptr
|
||||
void *GetOffloadPtr() const final {
|
||||
std::lock_guard<std::recursive_mutex> lock(ptr_mutex_);
|
||||
return offload_ptr_;
|
||||
}
|
||||
|
||||
// Return whether DeviceAddress has a valid ptr.
|
||||
bool IsPtrValid() const final {
|
||||
std::lock_guard<std::recursive_mutex> lock(ptr_mutex_);
|
||||
return ptr_ != nullptr || offload_ptr_ != nullptr;
|
||||
}
|
||||
|
||||
// Load first if data is offloaded and return the device ptr.
|
||||
void *GetValidPtr(size_t stream_id) final {
|
||||
std::lock_guard<std::recursive_mutex> lock(ptr_mutex_);
|
||||
if (mem_offloaded() && !Load(stream_id)) {
|
||||
MS_LOG(EXCEPTION) << "Load offloaded memory failed.";
|
||||
}
|
||||
return ptr_;
|
||||
}
|
||||
SwapEventPtr swap_event_{nullptr};
|
||||
mutable StorageInfo storage_info_;
|
||||
};
|
||||
} // namespace device
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -59,14 +59,16 @@ void MemorySwapActor::UpdateDeviceTensors(OpContext<mindspore::runtime::DeviceTe
|
|||
}
|
||||
}
|
||||
|
||||
void MemorySwapActor::GetDeviceTensors(std::vector<size_t> indexes, std::vector<DeviceTensor *> *device_tensors) {
|
||||
std::vector<DeviceTensor *> MemorySwapActor::GetDeviceTensors(const std::vector<size_t> &indexes) {
|
||||
std::vector<DeviceTensor *> device_tensors;
|
||||
for (const auto index : indexes) {
|
||||
if (index >= device_tensors_to_swap_.size()) {
|
||||
MS_LOG(EXCEPTION) << "Device tensor index[" << index << "] out of range[" << device_tensors_to_swap_.size()
|
||||
<< "].";
|
||||
}
|
||||
device_tensors->emplace_back(device_tensors_to_swap_[index]);
|
||||
device_tensors.emplace_back(device_tensors_to_swap_[index]);
|
||||
}
|
||||
return std::move(device_tensors);
|
||||
}
|
||||
|
||||
void MemorySwapActor::AllocDeviceContinuousMem(const std::vector<DeviceTensor *> &device_tensors) {
|
||||
|
@ -85,34 +87,31 @@ void MemorySwapActor::AllocDeviceContinuousMem(const std::vector<DeviceTensor *>
|
|||
}
|
||||
}
|
||||
|
||||
// Synchronously move every tensor in `device_tensors` to the storage type `to` on the default stream.
// A null entry raises an exception; a move that reports failure is logged and the remaining tensors are
// still processed.
void MemorySwapActor::Swap(device::StorageType to, const std::vector<DeviceTensor *> &device_tensors) {
  for (const auto &device_tensor : device_tensors) {
    MS_EXCEPTION_IF_NULL(device_tensor);
    // MoveTo reports failure through its return value; do not drop it silently.
    if (!device_tensor->MoveTo(to, false, kDefaultStreamIndex)) {
      MS_LOG(WARNING) << "Move device tensor to storage type " << static_cast<int>(to) << " failed.";
    }
  }
}
|
||||
|
||||
void MemorySwapActor::Run(OpContext<mindspore::runtime::DeviceTensor> *const context) {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
static std::map<device::SwapActionType, std::pair<device::StorageType, device::StorageType>> swap_from_to_map = {
|
||||
{device::SwapActionType::kHBM2DDR, {device::StorageType::kDevice, device::StorageType::kHost}},
|
||||
{device::SwapActionType::kDDR2HBM, {device::StorageType::kHost, device::StorageType::kDevice}},
|
||||
{device::SwapActionType::kDDR2DISK, {device::StorageType::kHost, device::StorageType::kFile}},
|
||||
{device::SwapActionType::kDISK2DDR, {device::StorageType::kFile, device::StorageType::kHost}},
|
||||
{device::SwapActionType::kHBM2DISK, {device::StorageType::kDevice, device::StorageType::kFile}},
|
||||
{device::SwapActionType::kDISK2HBM, {device::StorageType::kFile, device::StorageType::kDevice}}};
|
||||
static std::map<device::SwapActionType, device::StorageType> swap_to_map = {
|
||||
{device::SwapActionType::kHBM2DDR, device::StorageType::kHost},
|
||||
{device::SwapActionType::kDDR2HBM, device::StorageType::kDevice},
|
||||
{device::SwapActionType::kDDR2DISK, device::StorageType::kFile},
|
||||
{device::SwapActionType::kDISK2DDR, device::StorageType::kHost},
|
||||
{device::SwapActionType::kHBM2DISK, device::StorageType::kFile},
|
||||
{device::SwapActionType::kDISK2HBM, device::StorageType::kDevice}};
|
||||
UpdateDeviceTensors(context);
|
||||
for (const auto &action : swap_actions_) {
|
||||
const auto action_type = action.first;
|
||||
const auto &device_tensor_indexes = action.second;
|
||||
std::vector<DeviceTensor *> device_tensors;
|
||||
GetDeviceTensors(device_tensor_indexes, &device_tensors);
|
||||
const auto &device_tensors = GetDeviceTensors(device_tensor_indexes);
|
||||
if (action_type == device::SwapActionType::kAllocHBM) {
|
||||
AllocDeviceContinuousMem(device_tensors);
|
||||
} else if (action_type != device::SwapActionType::kUnDefined) {
|
||||
const auto from = swap_from_to_map[action_type].first;
|
||||
const auto to = swap_from_to_map[action_type].second;
|
||||
Swap(from, to, device_tensors);
|
||||
Swap(swap_to_map[action_type], device_tensors);
|
||||
} else {
|
||||
MS_LOG(WARNING) << "Unknown swap action type, skip.";
|
||||
}
|
||||
|
|
|
@ -52,9 +52,9 @@ class MemorySwapActor : public AbstractActor {
|
|||
|
||||
private:
|
||||
void AllocDeviceContinuousMem(const std::vector<DeviceTensor *> &device_tensors);
|
||||
static void Swap(device::StorageType from, device::StorageType to, const std::vector<DeviceTensor *> &device_tensors);
|
||||
static void Swap(device::StorageType to, const std::vector<DeviceTensor *> &device_tensors);
|
||||
void UpdateDeviceTensors(OpContext<DeviceTensor> *context);
|
||||
void GetDeviceTensors(std::vector<size_t> indexes, std::vector<DeviceTensor *> *device_tensors);
|
||||
std::vector<DeviceTensor *> GetDeviceTensors(const std::vector<size_t> &indexes);
|
||||
|
||||
protected:
|
||||
size_t stream_id_;
|
||||
|
|
|
@ -31,6 +31,7 @@
|
|||
namespace mindspore {
|
||||
namespace runtime {
|
||||
constexpr size_t kFirstVirtualNode = 0;
|
||||
constexpr size_t kSecondVirtualNodeOffset = 1;
|
||||
namespace {
|
||||
AbstractActor *GetCtrlActor(const ControlNodeParserPtr &parser, const KernelGraph *graph, const string &actor_suffix) {
|
||||
MS_EXCEPTION_IF_NULL(parser);
|
||||
|
@ -55,10 +56,12 @@ std::map<size_t, size_t> GetActionTensors(const std::shared_ptr<device::SwapActi
|
|||
std::map<size_t, size_t> tensor_indexes;
|
||||
std::set<size_t> is_real_parameter;
|
||||
for (const auto &tensor_action : swap_action->actions_) {
|
||||
MS_EXCEPTION_IF_NULL(tensor_action);
|
||||
if (tensor_action->tensor_id_ >= swap_strategy->tensor_infos_.size()) {
|
||||
MS_LOG(EXCEPTION) << "Invalid tensor id " << tensor_action->tensor_id_;
|
||||
}
|
||||
const auto &tensor_info = swap_strategy->tensor_infos_[tensor_action->tensor_id_];
|
||||
MS_EXCEPTION_IF_NULL(tensor_info);
|
||||
std::vector<size_t> real_tensor_ids;
|
||||
if (tensor_info->fused_tensor_ids_.empty()) {
|
||||
real_tensor_ids.emplace_back(tensor_info->tensor_id_);
|
||||
|
@ -73,6 +76,7 @@ std::map<size_t, size_t> GetActionTensors(const std::shared_ptr<device::SwapActi
|
|||
MS_LOG(EXCEPTION) << "Invalid tensor id " << real_tensor_id;
|
||||
}
|
||||
const auto &real_tensor_info = swap_strategy->tensor_infos_[real_tensor_id];
|
||||
MS_EXCEPTION_IF_NULL(real_tensor_info);
|
||||
const auto &node = real_tensor_info->node_;
|
||||
const auto &real_parameter_iter = real_parameters.find(node);
|
||||
if (real_parameter_iter == real_parameters.end()) {
|
||||
|
@ -105,14 +109,16 @@ void GenActionIndexList(const std::map<size_t, size_t> &tensors_id_index_map,
|
|||
std::map<device::SwapActionType, vector<size_t>> move_action_map;
|
||||
const auto &actions = swap_action->actions_;
|
||||
for (const auto &tensor_action : actions) {
|
||||
MS_EXCEPTION_IF_NULL(tensor_action);
|
||||
if (tensor_action->tensor_id_ >= swap_strategy->tensor_infos_.size()) {
|
||||
MS_LOG(EXCEPTION) << "Invalid tensor id " << tensor_action->tensor_id_;
|
||||
}
|
||||
const auto &tensor_info = swap_strategy->tensor_infos_[tensor_action->tensor_id_];
|
||||
MS_EXCEPTION_IF_NULL(tensor_info);
|
||||
if (tensor_action->action_ == device::SwapActionType::kAllocHBM) {
|
||||
std::vector<size_t> indexes;
|
||||
std::transform(tensor_info->fused_tensor_ids_.begin(), tensor_info->fused_tensor_ids_.end(),
|
||||
std::back_inserter(indexes), [](size_t index) { return index; });
|
||||
std::copy(tensor_info->fused_tensor_ids_.begin(), tensor_info->fused_tensor_ids_.end(),
|
||||
std::back_inserter(indexes));
|
||||
alloc_action_map.emplace_back(indexes);
|
||||
} else if (tensor_action->action_ != device::SwapActionType::kUnDefined) {
|
||||
const auto tensor_id = tensor_info->tensor_id_;
|
||||
|
@ -205,7 +211,7 @@ std::vector<std::vector<MemSwapActorPtr>> MemSwapScheduler::Build(const GraphCom
|
|||
const auto device_context = graph_compiler_info.device_contexts_[i];
|
||||
const auto &graph = graph_compiler_info.graphs_[i];
|
||||
std::vector<MemSwapActorPtr> actors;
|
||||
if (device_context == nullptr || graph->is_dynamic_shape()) {
|
||||
if (device_context == nullptr || graph == nullptr || graph->is_dynamic_shape()) {
|
||||
swap_actors.emplace_back(actors);
|
||||
continue;
|
||||
}
|
||||
|
@ -228,7 +234,7 @@ AbstractActor *MemSwapScheduler::GetActorForLink(size_t id, const std::shared_pt
|
|||
if (ret == nullptr) {
|
||||
ret = actor_set->data_prepare_actor_.get();
|
||||
}
|
||||
} else if (id == graph->execution_order().size()) {
|
||||
} else if (id == graph->execution_order().size() + kSecondVirtualNodeOffset) {
|
||||
ret = dynamic_cast<ExitActor *>(GetCtrlActor(parser, graph.get(), kExitActorNameSuffix));
|
||||
if (ret == nullptr) {
|
||||
ret = actor_set->loop_count_actor_.get();
|
||||
|
@ -258,7 +264,9 @@ void MemSwapScheduler::Link(const GraphCompilerInfo &graph_compiler_info, ActorS
|
|||
continue;
|
||||
}
|
||||
const auto &strategy = strategy_iter->second;
|
||||
MS_EXCEPTION_IF_NULL(strategy);
|
||||
for (const auto &link : strategy->links_) {
|
||||
MS_EXCEPTION_IF_NULL(link);
|
||||
const auto from_actor =
|
||||
GetActorForLink(link->from_, strategy, graph, graph_compiler_info.control_node_parser_, actor_set);
|
||||
MS_EXCEPTION_IF_NULL(from_actor);
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
#include <map>
|
||||
#include "runtime/hardware/device_type.h"
|
||||
#include "runtime/device/device_address.h"
|
||||
#include "runtime/device/gsm/swap_manager.h"
|
||||
#include "runtime/collective/collective_communication_lib.h"
|
||||
#include "runtime/collective/collective_comm_lib_loader.h"
|
||||
#include "backend/common/session/kernel_graph.h"
|
||||
|
@ -160,6 +161,8 @@ class BACKEND_EXPORT DeviceResManager {
|
|||
// Return collective communication object for caller to access
|
||||
CollectiveCommunicationLib *collective_comm_lib() const { return collective_comm_lib_; }
|
||||
|
||||
std::shared_ptr<SwapManager> swap_manager() const { return swap_manager_; }
|
||||
|
||||
protected:
|
||||
// Ensure the thread safety for allocating device memory.
|
||||
mutable std::mutex alloc_mem_mutex_;
|
||||
|
@ -169,6 +172,8 @@ class BACKEND_EXPORT DeviceResManager {
|
|||
|
||||
DeviceContext *device_context_{nullptr};
|
||||
|
||||
std::shared_ptr<SwapManager> swap_manager_{nullptr};
|
||||
|
||||
private:
|
||||
template <class... Args>
|
||||
friend class DeviceInterface;
|
||||
|
|
Loading…
Reference in New Issue