DeviceAddress::MoveTo and Add SwapManager

tanghuikang 2023-02-13 20:59:39 +08:00
parent 0101b59ec2
commit 8569480db6
14 changed files with 928 additions and 111 deletions

mindspore/ccsrc/runtime/device/CMakeLists.txt View File

@@ -3,7 +3,7 @@ file(GLOB_RECURSE DEVICE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "common/*
"memory_manager.cc" "kernel_runtime_manager.cc" "convert_tensor_utils.cc" "memory_scheduler.cc"
"memory_offload_strategy.cc" "launch_kernel.cc" "launch_mul.cc" "tensor_array.cc"
"ms_device_shape_transfer.cc" "context_extends.cc" "stream_synchronizer.cc" "tensors_queue.cc" "auto_mem_offload.cc"
"common_somas_allocator.cc" "device_address_utils.cc"
"common_somas_allocator.cc" "device_address_utils.cc" "loadable_device_address.cc"
)
if("${ENABLE_HIDDEN}" STREQUAL "OFF" AND NOT MSVC)

mindspore/ccsrc/runtime/device/device_address.h View File

@@ -217,7 +217,7 @@ class DeviceAddress : public mindspore::DeviceSync {
virtual bool Load(size_t) { MS_LOG(EXCEPTION) << "Not implemented."; }
// Move data to destination hardware and free resource on source hardware
virtual bool MoveTo(StorageType, bool) { MS_LOG(EXCEPTION) << "Not implemented."; }
virtual bool MoveTo(StorageType, bool, size_t) { MS_LOG(EXCEPTION) << "Not implemented."; }
virtual bool Wait() const { MS_LOG(EXCEPTION) << "Not implemented."; }
@@ -230,6 +230,10 @@ class DeviceAddress : public mindspore::DeviceSync {
virtual void SetStorageInfo(const StorageInfo &) {}
virtual void HandOver(DeviceAddress *other) {
MS_EXCEPTION_IF_NULL(other);
if (other == this) {
return;
}
other->set_ptr(GetMutablePtr());
other->set_from_mem_pool(from_mem_pool());
set_ptr(nullptr);

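The new overload adds a stream id parameter, so swaps can be issued on a chosen stream, and the caller names only the destination tier. Below is a minimal standalone sketch of that contract, assuming only the StorageType tiers this commit uses (device, host, file); it is an illustration, not code from the commit.

#include <cstddef>

// Toy model of DeviceAddress::MoveTo(StorageType, bool, size_t): the caller
// names the destination tier and a stream id; the source tier is implied by
// wherever the payload currently lives.
enum class StorageType { kDevice, kHost, kFile };

struct ToyAddress {
  StorageType where = StorageType::kDevice;

  bool MoveTo(StorageType dst, bool async, size_t stream_id) {
    (void)async;      // a real implementation may return before the copy ends
    (void)stream_id;  // a real implementation issues the copy on this stream
    if (dst == where) {
      return true;  // already on the destination tier
    }
    // ... copy the payload to dst, then free the source-tier resource ...
    where = dst;
    return true;
  }
};

int main() {
  ToyAddress addr;
  return addr.MoveTo(StorageType::kHost, /*async=*/false, /*stream_id=*/0) ? 0 : 1;
}
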
mindspore/ccsrc/runtime/device/gsm/host_mem_pool.h View File

@@ -0,0 +1,38 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_GSM_HOST_MEM_POOL_H_
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_GSM_HOST_MEM_POOL_H_
#include <memory>
#include "include/backend/visible.h"
namespace mindspore {
namespace device {
class HostMemPool {
public:
explicit BACKEND_EXPORT HostMemPool(size_t max_size) : max_size_(max_size) {}
virtual ~HostMemPool() = default;
virtual void *Malloc(size_t size) { return nullptr; }
virtual void Free(const void *ptr) {}
protected:
size_t max_size_;
};
using HostMemPoolPtr = std::shared_ptr<HostMemPool>;
} // namespace device
} // namespace mindspore
#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_GSM_HOST_MEM_POOL_H_

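HostMemPool as committed is a stub: Malloc returns nullptr, Free is a no-op, and only max_size_ is stored. The following is a hypothetical subclass, not part of this commit, sketching how a concrete pool could enforce that cap with plain std::malloc (it assumes the HostMemPool declaration above is in scope).

#include <cstddef>
#include <cstdlib>
#include <unordered_map>

// Hypothetical pool backed by std::malloc; tracks per-pointer sizes so the
// max_size_ budget inherited from HostMemPool can be enforced and refunded.
class NaiveHostMemPool : public HostMemPool {
 public:
  explicit NaiveHostMemPool(size_t max_size) : HostMemPool(max_size) {}
  void *Malloc(size_t size) override {
    if (used_ + size > max_size_) {
      return nullptr;  // over budget: caller falls back to swapping
    }
    void *ptr = std::malloc(size);
    if (ptr != nullptr) {
      sizes_[ptr] = size;
      used_ += size;
    }
    return ptr;
  }
  void Free(const void *ptr) override {
    const auto iter = sizes_.find(const_cast<void *>(ptr));
    if (iter == sizes_.end()) {
      return;
    }
    used_ -= iter->second;
    std::free(iter->first);
    sizes_.erase(iter);
  }
 private:
  size_t used_{0};
  std::unordered_map<void *, size_t> sizes_;
};
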
mindspore/ccsrc/runtime/device/gsm/io_handle.cc View File

@@ -0,0 +1,91 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/device/gsm/io_handle.h"
#include <memory>
#include "utils/system/env.h"
namespace mindspore {
namespace device {
constexpr size_t kFileHeadOffset = 0;
bool IOHandle::Read(const std::string &file_name, void *data, size_t byte_num) {
if (aio_ != nullptr) {
return aio_->Read(file_name, data, byte_num);
}
const auto &fs = system::Env::GetFileSystem();
MS_EXCEPTION_IF_NULL(fs);
auto file = fs->CreateWriteFile(GetSwapFileWholeName(file_name));
MS_EXCEPTION_IF_NULL(file);
return file->PRead(data, byte_num, kFileHeadOffset);
}
bool IOHandle::Write(const std::string &file_name, const void *data, size_t byte_num) {
if (aio_ != nullptr) {
return aio_->Write(file_name, data, byte_num);
}
const auto &fs = system::Env::GetFileSystem();
MS_EXCEPTION_IF_NULL(fs);
auto file = fs->CreateWriteFile(GetSwapFileWholeName(file_name));
MS_EXCEPTION_IF_NULL(file);
return file->PWrite(data, byte_num, kFileHeadOffset);
}
bool IOHandle::ReadAsync(const std::string &file_name, void *data, size_t byte_num, AsyncIOToken *token) {
if (aio_ != nullptr) {
return aio_->ReadAsync(file_name, data, byte_num, token);
}
const auto &fs = system::Env::GetFileSystem();
MS_EXCEPTION_IF_NULL(fs);
auto file = fs->CreateWriteFile(GetSwapFileWholeName(file_name));
MS_EXCEPTION_IF_NULL(file);
*token = kInvalidAsyncIOToken;
return file->PRead(data, byte_num, kFileHeadOffset);
}
bool IOHandle::WriteAsync(const std::string &file_name, const void *data, size_t byte_num, AsyncIOToken *token) {
if (aio_ != nullptr) {
return aio_->WriteAsync(file_name, data, byte_num, token);
}
const auto &fs = system::Env::GetFileSystem();
MS_EXCEPTION_IF_NULL(fs);
auto file = fs->CreateWriteFile(GetSwapFileWholeName(file_name));
MS_EXCEPTION_IF_NULL(file);
*token = kInvalidAsyncIOToken;
return file->PWrite(data, byte_num, kFileHeadOffset);
}
bool IOHandle::Wait(AsyncIOToken token) { return aio_ == nullptr || aio_->Wait(token); }
bool IOHandle::DeleteSwapFile(const std::string &file_name) const {
const auto &fs = system::Env::GetFileSystem();
MS_EXCEPTION_IF_NULL(fs);
return fs->DeleteFile(GetSwapFileWholeName(file_name));
}
bool IOHandle::CreateSwapFile(const std::string &file_name) const {
const auto &fs = system::Env::GetFileSystem();
MS_EXCEPTION_IF_NULL(fs);
auto file = fs->CreateWriteFile(GetSwapFileWholeName(file_name));
MS_EXCEPTION_IF_NULL(file);
return true;
}
std::string IOHandle::GetSwapFileWholeName(const std::string &file_name) const {
return swap_path_ + file_name + ".data";
}
} // namespace device
} // namespace mindspore

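Worth noting: when no AsyncIO backend is configured, ReadAsync and WriteAsync fall back to synchronous file I/O and hand back kInvalidAsyncIOToken, and Wait() reports success immediately because aio_ == nullptr short-circuits. A small illustrative fragment of the resulting caller contract (assumes a default-constructed IOHandle; not from the commit):

IOHandle io;
AsyncIOToken token = kInvalidAsyncIOToken;
char buffer[512] = {0};
if (io.WriteAsync("tensor_0", buffer, sizeof(buffer), &token)) {
  // With a real AsyncIO backend, compute could overlap the write here.
  if (!io.Wait(token)) {  // no-op success in the synchronous fallback
    // handle I/O failure
  }
}
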
mindspore/ccsrc/runtime/device/gsm/io_handle.h View File

@@ -0,0 +1,75 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_GSM_IO_HANDLE_H_
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_GSM_IO_HANDLE_H_
#include <memory>
#include <mutex>
#include <queue>
#include <string>
#include <utility>
#include "include/backend/visible.h"
namespace mindspore {
namespace device {
using AsyncIOToken = size_t;
constexpr AsyncIOToken kInvalidAsyncIOToken = 0;
struct AsyncIOConf {
size_t block_size;
size_t queue_depth;
};
class AsyncIO {
public:
AsyncIO() = default;
virtual ~AsyncIO() = default;
virtual bool Init(const AsyncIOConf &conf) = 0;
virtual bool Read(const std::string &file_name, void *data, size_t byte_num) = 0;
virtual bool Write(const std::string &file_name, const void *data, size_t byte_num) = 0;
virtual bool ReadAsync(const std::string &file_name, void *data, size_t byte_num, AsyncIOToken *token) = 0;
virtual bool WriteAsync(const std::string &file_name, const void *data, size_t byte_num, AsyncIOToken *token) = 0;
virtual bool Wait(AsyncIOToken token) = 0;
};
class BACKEND_EXPORT IOHandle {
public:
IOHandle() = default;
~IOHandle() = default;
bool DeleteSwapFile(const std::string &file_name) const;
bool CreateSwapFile(const std::string &file_name) const;
bool Read(const std::string &file_name, void *data, size_t byte_num);
bool Write(const std::string &file_name, const void *data, size_t byte_num);
bool ReadAsync(const std::string &file_name, void *data, size_t byte_num, AsyncIOToken *token);
bool WriteAsync(const std::string &file_name, const void *data, size_t byte_num, AsyncIOToken *token);
bool Wait(AsyncIOToken sync_token);
private:
std::string GetSwapFileWholeName(const std::string &file_name) const;
private:
bool use_aio_{false};
size_t aio_block_size_{0};
size_t aio_queue_depth_{0};
std::string swap_path_;
std::shared_ptr<AsyncIO> aio_{nullptr};
};
using IOHandlePtr = std::shared_ptr<IOHandle>;
} // namespace device
} // namespace mindspore
#endif  // MINDSPORE_CCSRC_RUNTIME_DEVICE_GSM_IO_HANDLE_H_

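AsyncIO is pure virtual and no backend ships in this commit. For orientation, here is a hypothetical minimal backend that satisfies the interface by completing every request synchronously, mirroring the fallback path in io_handle.cc (it assumes the declarations above are in scope; the class name is an invention).

class SyncFallbackAIO : public AsyncIO {
 public:
  bool Init(const AsyncIOConf &conf) override {
    conf_ = conf;
    return true;
  }
  bool Read(const std::string &, void *, size_t) override { return true; }
  bool Write(const std::string &, const void *, size_t) override { return true; }
  bool ReadAsync(const std::string &file, void *data, size_t n, AsyncIOToken *token) override {
    *token = kInvalidAsyncIOToken;  // nothing to wait on
    return Read(file, data, n);
  }
  bool WriteAsync(const std::string &file, const void *data, size_t n, AsyncIOToken *token) override {
    *token = kInvalidAsyncIOToken;
    return Write(file, data, n);
  }
  bool Wait(AsyncIOToken) override { return true; }
 private:
  AsyncIOConf conf_{};
};
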
mindspore/ccsrc/runtime/device/gsm/swap_manager.cc View File

@@ -0,0 +1,190 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/device/gsm/swap_manager.h"
#include <string>
#include <utility>
namespace mindspore {
namespace device {
SwapManager::SwapManager(size_t stream_id, mindspore::device::DynamicMemPoolBestFit *device_memory_pool,
HostMemPoolPtr host_mem_pool)
: stream_id_(stream_id), device_memory_pool_(device_memory_pool), host_memory_pool_(std::move(host_mem_pool)) {
io_handle_ = std::make_shared<IOHandle>();
}
template <class Input, class Output>
bool SwapManager::TryAllocate(std::queue<const DeviceAddress *> queue, const Input &input,
Output (SwapManager::*allocate_func)(const Input &),
const std::function<bool(Output)> &success, Output *output) {
MS_EXCEPTION_IF_NULL(allocate_func);
MS_EXCEPTION_IF_NULL(output);
(*output) = (this->*allocate_func)(input);
if (success(*output)) {
return true;
}
// Wait swapping tensors.
while (!queue.empty()) {
const auto &front = queue.front();
MS_EXCEPTION_IF_NULL(front);
if (front->Wait()) {
(*output) = (this->*allocate_func)(input);
if (success(*output)) {
return true;
}
}
queue.pop();
}
return false;
}
void *SwapManager::AllocDeviceMemorySimply(const size_t &size) {
MS_EXCEPTION_IF_NULL(device_memory_pool_);
return device_memory_pool_->AllocTensorMem(size);
}
void *SwapManager::AllocDeviceMemory(size_t size) {
void *ret = nullptr;
void *(SwapManager::*allocate_func)(const size_t &) = &SwapManager::AllocDeviceMemorySimply;
std::function<bool(void *)> success = [](void *ptr) { return ptr != nullptr; };
std::lock_guard<std::mutex> lock(swapping_tensors_device_mutex_);
if (!TryAllocate(swapping_tensors_device_, size, allocate_func, success, &ret)) {
MS_LOG(WARNING) << "Allocate device memory failed, size: " << size;
}
return ret;
}
std::vector<void *> SwapManager::AllocDeviceContinuousMemSimply(const std::vector<size_t> &size_list) {
MS_EXCEPTION_IF_NULL(device_memory_pool_);
return device_memory_pool_->AllocContinuousTensorMem(size_list);
}
std::vector<void *> SwapManager::AllocDeviceContinuousMem(const std::vector<size_t> &size_list) {
std::vector<void *> ret;
std::vector<void *> (SwapManager::*allocate_func)(const std::vector<size_t> &) =
&SwapManager::AllocDeviceContinuousMemSimply;
std::function<bool(std::vector<void *>)> success = [](const std::vector<void *> &ptrs) { return !ptrs.empty(); };
std::lock_guard<std::mutex> lock(swapping_tensors_device_mutex_);
if (!TryAllocate(swapping_tensors_device_, size_list, allocate_func, success, &ret)) {
MS_LOG(WARNING) << "Allocate continuous device mem failed, size list: " << size_list;
}
return ret;
}
void SwapManager::FreeDeviceMemory(void *ptr) {
MS_EXCEPTION_IF_NULL(device_memory_pool_);
device_memory_pool_->FreeTensorMem(ptr);
}
void *SwapManager::AllocHostMemorySimply(const size_t &size) {
MS_EXCEPTION_IF_NULL(host_memory_pool_);
return host_memory_pool_->Malloc(size);
}
void *SwapManager::AllocHostMemory(size_t size) {
void *ret = nullptr;
void *(SwapManager::*allocate_func)(const size_t &) = &SwapManager::AllocHostMemorySimply;
std::function<bool(void *)> success = [](void *ptr) { return ptr != nullptr; };
std::lock_guard<std::mutex> lock(swapping_tensors_host_mutex_);
if (!TryAllocate(swapping_tensors_host_, size, allocate_func, success, &ret)) {
MS_LOG(WARNING) << "Allocate host memory failed, size: " << size;
}
return ret;
}
void SwapManager::FreeHostMemory(void *ptr) {
MS_EXCEPTION_IF_NULL(host_memory_pool_);
host_memory_pool_->Free(ptr);
}
bool SwapManager::CreateFile(const std::string &file_name) {
MS_EXCEPTION_IF_NULL(io_handle_);
return io_handle_->CreateSwapFile(file_name);
}
bool SwapManager::DeleteFile(const std::string &file_name) {
MS_EXCEPTION_IF_NULL(io_handle_);
const auto &iter = file_size_.find(file_name);
if (iter == file_size_.end()) {
MS_LOG(WARNING) << "Can not file size for file[" << file_name << "]";
} else {
current_used_file_size_ -= iter->second;
iter->second = 0;
}
return io_handle_->DeleteSwapFile(file_name);
}
bool SwapManager::FileToHostMemory(void *host_memory, const std::string &file_name, size_t byte_num, bool async,
AsyncIOToken *sync_key) {
MS_EXCEPTION_IF_NULL(io_handle_);
if (async) {
return io_handle_->ReadAsync(file_name, host_memory, byte_num, sync_key);
} else {
return io_handle_->Read(file_name, host_memory, byte_num);
}
}
bool SwapManager::EnoughFileSpace(const size_t &size) { return current_used_file_size_ + size <= max_file_size_; }
bool SwapManager::HostMemoryToFile(const std::string &file_name, const void *data, size_t byte_num, bool async,
AsyncIOToken *sync_key) {
MS_EXCEPTION_IF_NULL(io_handle_);
bool (SwapManager::*allocate_func)(const size_t &size) = &SwapManager::EnoughFileSpace;
std::function<bool(bool)> success = [](bool ret) { return ret; };
{
std::lock_guard<std::mutex> lock(swapping_tensors_file_mutex_);
bool enough = false;
if (!TryAllocate(swapping_tensors_file_, byte_num, allocate_func, success, &enough)) {
MS_LOG(WARNING) << "There is no enough disk space for file writing, size: " << byte_num;
return false;
}
}
current_used_file_size_ += byte_num;
file_size_[file_name] = byte_num;
if (async) {
return io_handle_->WriteAsync(file_name, data, byte_num, sync_key);
} else {
return io_handle_->Write(file_name, data, byte_num);
}
}
bool SwapManager::WaitAsyncIO(mindspore::device::AsyncIOToken sync_token) {
MS_EXCEPTION_IF_NULL(io_handle_);
return io_handle_->Wait(sync_token);
}
void SwapManager::AddSwappableTensor(const DeviceAddressPtr &device_address) {
swappable_tensors_.push(device_address);
}
void SwapManager::AddSwappingTensor(const mindspore::device::DeviceAddress *device_address) {
if (device_address == nullptr) {
return;
}
if (device_address->status() == DeviceAddressStatus::kInFileToHost) {
std::lock_guard<std::mutex> lock(swapping_tensors_file_mutex_);
(void)swapping_tensors_file_.push(device_address);
} else if (device_address->status() == DeviceAddressStatus::kInDeviceToHost) {
std::lock_guard<std::mutex> lock(swapping_tensors_device_mutex_);
(void)swapping_tensors_device_.push(device_address);
} else {
std::lock_guard<std::mutex> lock(swapping_tensors_host_mutex_);
(void)swapping_tensors_host_.push(device_address);
}
}
} // namespace device
} // namespace mindspore

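The TryAllocate template above is the heart of the manager: attempt the allocation once, then wait on in-flight swap transfers one at a time, retrying after each completion since any finished swap may have released memory. A simplified standalone restatement of the same loop, with the member-function pointers replaced by std::function for readability (illustrative, not the commit's code):

#include <functional>
#include <queue>

template <class Output>
bool TryWithRetry(std::queue<std::function<bool()>> *in_flight,
                  const std::function<Output()> &allocate,
                  const std::function<bool(const Output &)> &ok, Output *out) {
  *out = allocate();
  if (ok(*out)) {
    return true;
  }
  while (!in_flight->empty()) {
    if (in_flight->front()()) {  // wait for one pending swap to complete
      *out = allocate();         // its freed memory may satisfy the request
      if (ok(*out)) {
        return true;
      }
    }
    in_flight->pop();
  }
  return false;  // nothing left to reclaim
}
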
mindspore/ccsrc/runtime/device/gsm/swap_manager.h View File

@@ -0,0 +1,91 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_GSM_SWAP_MANAGER_H_
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_GSM_SWAP_MANAGER_H_
#include <memory>
#include <queue>
#include <string>
#include <vector>
#include "common/mem_reuse/mem_dynamic_allocator.h"
#include "runtime/device/device_address.h"
#include "runtime/device/gsm/io_handle.h"
#include "runtime/device/gsm/host_mem_pool.h"
#include "include/backend/visible.h"
namespace mindspore {
namespace device {
class BACKEND_EXPORT SwapManager {
public:
SwapManager(size_t stream_id, DynamicMemPoolBestFit *device_memory_pool, HostMemPoolPtr host_memory_pool);
~SwapManager() = default;
// Device memory
void *AllocDeviceMemory(size_t size);
std::vector<void *> AllocDeviceContinuousMem(const std::vector<size_t> &size_list);
void FreeDeviceMemory(void *ptr);
// Host memory
void *AllocHostMemory(size_t size);
void FreeHostMemory(void *ptr);
// File
bool CreateFile(const std::string &file_name);
bool DeleteFile(const std::string &file_name);
bool FileToHostMemory(void *host_memory, const std::string &file_name, size_t byte_num, bool async,
AsyncIOToken *sync_token);
bool HostMemoryToFile(const std::string &file_name, const void *data, size_t byte_num, bool async,
AsyncIOToken *sync_token);
bool WaitAsyncIO(AsyncIOToken sync_token);
// Swapping and swappable tensors
void AddSwappableTensor(const DeviceAddressPtr &device_address);
void AddSwappingTensor(const DeviceAddress *device_address);
private:
void *AllocDeviceMemorySimply(const size_t &size);
std::vector<void *> AllocDeviceContinuousMemSimply(const std::vector<size_t> &size_list);
void *AllocHostMemorySimply(const size_t &size);
bool EnoughFileSpace(const size_t &size);
template <class Input, class Output>
bool TryAllocate(std::queue<const DeviceAddress *> queue, const Input &input,
Output (SwapManager::*allocate_func)(const Input &), const std::function<bool(Output)> &success,
Output *output);
private:
size_t stream_id_;
DynamicMemPoolBestFit *device_memory_pool_;
HostMemPoolPtr host_memory_pool_;
size_t max_file_size_{0};
size_t current_used_file_size_{0};
HashMap<std::string, size_t> file_size_;
struct compare {
bool operator()(const DeviceAddressPtr &l, const DeviceAddressPtr &r) const { return l->GetSize() < r->GetSize(); }
};
std::priority_queue<DeviceAddressPtr, std::vector<DeviceAddressPtr>, compare> swappable_tensors_;
std::mutex swapping_tensors_device_mutex_;
std::queue<const DeviceAddress *> swapping_tensors_device_;
std::mutex swapping_tensors_host_mutex_;
std::queue<const DeviceAddress *> swapping_tensors_host_;
std::mutex swapping_tensors_file_mutex_;
std::queue<const DeviceAddress *> swapping_tensors_file_;
IOHandlePtr io_handle_;
};
} // namespace device
} // namespace mindspore
#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_GSM_SWAP_MANAGER_H_

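A sketch of the intended call pattern, under stated assumptions: device_pool stands in for a concrete DynamicMemPoolBestFit, and NaiveHostMemPool is the hypothetical subclass sketched after host_mem_pool.h above; nothing here is from the commit.

// device_pool and NaiveHostMemPool are placeholders, not APIs added here.
auto host_pool = std::make_shared<NaiveHostMemPool>(1UL << 30);  // hypothetical 1 GiB pool
SwapManager swap_manager(/*stream_id=*/0, device_pool, host_pool);
void *device_mem = swap_manager.AllocDeviceMemory(1 << 20);  // may wait on in-flight swaps
void *host_mem = swap_manager.AllocHostMemory(1 << 20);
AsyncIOToken token = kInvalidAsyncIOToken;
if (swap_manager.CreateFile("tensor_0") &&
    swap_manager.HostMemoryToFile("tensor_0", host_mem, 1 << 20, /*async=*/false, &token)) {
  // host_mem contents are now persisted in swap file "tensor_0"
}
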
mindspore/ccsrc/runtime/device/gsm/swap_strategy_builder.h View File

@@ -20,9 +20,10 @@
#include <queue>
#include "runtime/device/gsm/swap_strategy.h"
#include "runtime/device/gsm/mem_usage_analyzer.h"
#include "include/backend/visible.h"
namespace mindspore {
namespace device {
class SwapStrategyBuilder {
class BACKEND_EXPORT SwapStrategyBuilder {
public:
SwapStrategyBuilder() = default;
~SwapStrategyBuilder() = default;

mindspore/ccsrc/runtime/device/loadable_device_address.cc View File

@@ -0,0 +1,324 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/device/loadable_device_address.h"
#include "include/common/debug/common.h"
namespace mindspore {
namespace device {
namespace {
constexpr size_t kFileAlignSize = 512;
} // namespace
bool LoadableDeviceAddress::Offload(size_t stream_id) {
std::lock_guard<std::recursive_mutex> lock(ptr_mutex_);
if (mem_offloaded_) {
MS_LOG(WARNING) << "Trying to offload an offloaded AscendDeviceAddress.";
return true;
}
MS_EXCEPTION_IF_NULL(ptr_);
auto device_context = GetDeviceContext();
MS_EXCEPTION_IF_NULL(device_context);
offload_ptr_ = device_context->device_res_manager_->AllocateOffloadMemory(size_);
if (offload_ptr_ == nullptr) {
MS_LOG(EXCEPTION) << "Alloc host memory for offloading failed, size: " << size_ << ".";
}
if (!AsyncDeviceToHost({}, size_, kTypeUnknown, offload_ptr_, stream_id)) {
return false;
}
device_context->device_res_manager_->FreeMemory(ptr_);
ptr_ = nullptr;
mem_offloaded_ = true;
return true;
}
bool LoadableDeviceAddress::Load(size_t stream_id) {
std::lock_guard<std::recursive_mutex> lock(ptr_mutex_);
if (!mem_offloaded_) {
MS_LOG(DEBUG) << "Trying to load a loaded AscendDeviceAddress.";
return true;
}
MS_EXCEPTION_IF_NULL(offload_ptr_);
auto device_context = GetDeviceContext();
MS_EXCEPTION_IF_NULL(device_context);
if (ptr_ == nullptr && !device_context->device_res_manager_->AllocateMemory(this)) {
MS_LOG(EXCEPTION) << "Alloc memory for loading failed, size: " << size_ << ".";
}
MS_EXCEPTION_IF_NULL(ptr_);
if (!AsyncHostToDevice({}, size_, kTypeUnknown, offload_ptr_, stream_id)) {
return false;
}
device_context->device_res_manager_->FreeOffloadMemory(offload_ptr_);
offload_ptr_ = nullptr;
mem_offloaded_ = false;
return true;
}
bool LoadableDeviceAddress::MoveTo(mindspore::device::StorageType dst, bool async, size_t stream_id) {
bool ret = Wait();
if (!ret) {
MS_LOG(WARNING) << "Wait swapping DeviceAddress failed. Status: " << status_;
return false;
}
if (dst == StorageType::kDevice) {
if (!MoveToDevice(async, stream_id)) {
MS_LOG(WARNING) << "Move data to device failed.";
return false;
}
} else if (dst == StorageType::kHost) {
if (!MoveToHost(async, stream_id)) {
MS_LOG(WARNING) << "Move data to host failed.";
return false;
}
} else if (dst == StorageType::kFile) {
if (!MoveToFile(async, stream_id)) {
MS_LOG(WARNING) << "Move data to file failed.";
return false;
}
}
return true;
}
bool LoadableDeviceAddress::MoveToHost(bool async, size_t stream_id) const {
const auto device_context = GetDeviceContext();
MS_EXCEPTION_IF_NULL(device_context);
const auto swap_manager = device_context->device_res_manager_->swap_manager();
MS_EXCEPTION_IF_NULL(swap_manager);
storage_info_.host_ptr_ = swap_manager->AllocHostMemory(GetFileAlignSize());
MS_EXCEPTION_IF_NULL(storage_info_.host_ptr_);
if (status_ == DeviceAddressStatus::kInFile) {
if (!CopyFileToHost(storage_info_.host_ptr_, storage_info_.file_name_, GetFileAlignSize(), async)) {
MS_LOG(WARNING) << "Copy data from file to host failed.";
return false;
}
if (async) {
swap_manager->AddSwappingTensor(this);
status_ = DeviceAddressStatus::kInFileToHost;
} else {
swap_manager->DeleteFile(storage_info_.file_name_);
storage_info_.file_name_ = "";
status_ = DeviceAddressStatus::kInHost;
}
} else {
if (!CopyDeviceToHost(storage_info_.host_ptr_, ptr_, size_, async, stream_id)) {
MS_LOG(WARNING) << "Copy data from device to host failed.";
return false;
}
if (async) {
swap_manager->AddSwappingTensor(this);
status_ = DeviceAddressStatus::kInDeviceToHost;
} else {
swap_manager->FreeDeviceMemory(ptr_);
ptr_ = nullptr;
status_ = DeviceAddressStatus::kInHost;
}
}
return true;
}
bool LoadableDeviceAddress::MoveToDevice(bool async, size_t stream_id) const {
if (status_ == DeviceAddressStatus::kInDevice) {
return true;
}
const auto device_context = GetDeviceContext();
MS_EXCEPTION_IF_NULL(device_context);
const auto swap_manager = device_context->device_res_manager_->swap_manager();
MS_EXCEPTION_IF_NULL(swap_manager);
if (status_ == DeviceAddressStatus::kInFile && !MoveToHost(false, stream_id)) {
return false;
}
ptr_ = swap_manager->AllocDeviceMemory(size_);
MS_EXCEPTION_IF_NULL(ptr_);
if (!CopyHostToDevice(ptr_, storage_info_.host_ptr_, size_, async, stream_id)) {
MS_LOG(WARNING) << "Copy data from host to device failed.";
return false;
}
if (async) {
swap_manager->AddSwappingTensor(this);
status_ = DeviceAddressStatus::kInHostToDevice;
} else {
swap_manager->FreeHostMemory(storage_info_.host_ptr_);
storage_info_.host_ptr_ = nullptr;
status_ = DeviceAddressStatus::kInDevice;
}
return true;
}
bool LoadableDeviceAddress::MoveToFile(bool async, size_t stream_id) const {
if (status_ == DeviceAddressStatus::kInFile) {
return true;
}
const auto device_context = GetDeviceContext();
MS_EXCEPTION_IF_NULL(device_context);
const auto swap_manager = device_context->device_res_manager_->swap_manager();
MS_EXCEPTION_IF_NULL(swap_manager);
if (status_ == DeviceAddressStatus::kInDevice && !MoveToHost(false, stream_id)) {
return false;
}
storage_info_.file_name_ = GetSwapFileName();
if (!swap_manager->CreateFile(storage_info_.file_name_)) {
MS_LOG(WARNING) << "Create file for swapping failed.";
return false;
}
if (!CopyHostToFile(storage_info_.file_name_, storage_info_.host_ptr_, GetFileAlignSize(), async)) {
MS_LOG(WARNING) << "Copy data from host to file failed.";
return false;
}
if (async) {
swap_manager->AddSwappingTensor(this);
status_ = DeviceAddressStatus::kInHostToFile;
} else {
swap_manager->FreeHostMemory(storage_info_.host_ptr_);
storage_info_.host_ptr_ = nullptr;
status_ = DeviceAddressStatus::kInFile;
}
return true;
}
bool LoadableDeviceAddress::CopyHostToFile(const std::string &dst, const void *src, size_t size, bool async) const {
MS_EXCEPTION_IF_NULL(src);
const auto device_context = GetDeviceContext();
MS_EXCEPTION_IF_NULL(device_context);
const auto swap_manager = device_context->device_res_manager_->swap_manager();
MS_EXCEPTION_IF_NULL(swap_manager);
AsyncIOToken token;
bool ret = swap_manager->HostMemoryToFile(dst, src, size, async, &token);
if (!ret) {
MS_LOG(WARNING) << "Write data from ddr to file[" << dst << "] failed.";
return ret;
}
if (async) {
MS_EXCEPTION_IF_NULL(swap_event_);
swap_event_->aio_token_ = token;
}
return ret;
}
bool LoadableDeviceAddress::CopyFileToHost(void *dst, const std::string &src, size_t size, bool async) const {
MS_EXCEPTION_IF_NULL(dst);
const auto device_context = GetDeviceContext();
MS_EXCEPTION_IF_NULL(device_context);
const auto swap_manager = device_context->device_res_manager_->swap_manager();
MS_EXCEPTION_IF_NULL(swap_manager);
AsyncIOToken token;
bool ret = swap_manager->FileToHostMemory(dst, src, size, async, &token);
if (!ret) {
MS_LOG(WARNING) << "Read data from file[" << src << "] to ddr failed.";
return ret;
}
if (async) {
MS_EXCEPTION_IF_NULL(swap_event_);
swap_event_->aio_token_ = token;
}
return true;
}
size_t LoadableDeviceAddress::GetFileAlignSize() const {
return (size_ + kFileAlignSize - 1) / kFileAlignSize * kFileAlignSize;
}
std::string LoadableDeviceAddress::GetSwapFileName() const {
static size_t swap_file_index = 0;
return std::to_string(device_id()) + "_" + std::to_string(swap_file_index++) + "_" +
std::to_string(Common::GetTimeStamp());
}
void LoadableDeviceAddress::SetStorageInfo(const mindspore::device::StorageInfo &storage_info) {
std::lock_guard<std::recursive_mutex> lock(ptr_mutex_);
storage_info_ = storage_info;
if (storage_info_.host_ptr_ != nullptr) {
status_ = DeviceAddressStatus::kInHost;
} else if (!storage_info_.file_name_.empty()) {
status_ = DeviceAddressStatus::kInFile;
} else {
status_ = DeviceAddressStatus::kInDevice;
}
}
void LoadableDeviceAddress::HandOver(mindspore::device::DeviceAddress *other) {
DeviceAddress::HandOver(other);
// Use dynamic_cast so the null check below is meaningful when other is not loadable.
auto loadable_device_address = dynamic_cast<LoadableDeviceAddress *>(other);
if (loadable_device_address != nullptr) {
loadable_device_address->SetStorageInfo(storage_info_);
loadable_device_address->set_status(status_);
storage_info_.host_ptr_ = nullptr;
storage_info_.file_name_ = "";
status_ = DeviceAddressStatus::kInDevice;
}
}
bool LoadableDeviceAddress::Wait() const {
if (swap_event_ == nullptr || !swap_event_->NeedWait()) {
return true;
}
const auto device_context = GetDeviceContext();
MS_EXCEPTION_IF_NULL(device_context);
const auto swap_manager = device_context->device_res_manager_->swap_manager();
MS_EXCEPTION_IF_NULL(swap_manager);
if (swap_event_->device_event_ != nullptr && swap_event_->device_event_->NeedWait()) {
swap_event_->device_event_->WaitEvent();
} else if (swap_event_->aio_token_ != kInvalidAsyncIOToken) {
if (!swap_manager->WaitAsyncIO(swap_event_->aio_token_)) {
MS_LOG(WARNING) << "Wait aio failed.";
return false;
}
} else {
MS_LOG(WARNING) << "Device address is in moving, but no valid swap event can be found.";
}
if (status_ == DeviceAddressStatus::kInFileToHost) {
swap_manager->DeleteFile(storage_info_.file_name_);
status_ = DeviceAddressStatus::kInHost;
} else if (status_ == DeviceAddressStatus::kInDeviceToHost) {
swap_manager->FreeDeviceMemory(ptr_);
status_ = DeviceAddressStatus::kInHost;
} else {
swap_manager->FreeHostMemory(storage_info_.host_ptr_);
if (status_ == DeviceAddressStatus::kInHostToDevice) {
status_ = DeviceAddressStatus::kInHost;
} else {
status_ = DeviceAddressStatus::kInFile;
}
}
return true;
}
void LoadableDeviceAddress::SetOffloadPtr(void *offload_ptr) {
std::lock_guard<std::recursive_mutex> lock(ptr_mutex_);
offload_ptr_ = offload_ptr;
mem_offloaded_ = (offload_ptr != nullptr);
}
void *LoadableDeviceAddress::GetOffloadPtr() const {
std::lock_guard<std::recursive_mutex> lock(ptr_mutex_);
return offload_ptr_;
}
// Return whether DeviceAddress has a valid ptr.
bool LoadableDeviceAddress::IsPtrValid() const {
std::lock_guard<std::recursive_mutex> lock(ptr_mutex_);
return ptr_ != nullptr || offload_ptr_ != nullptr;
}
// Load first if data is offloaded and return the device ptr.
void *LoadableDeviceAddress::GetValidPtr(size_t stream_id) {
std::lock_guard<std::recursive_mutex> lock(ptr_mutex_);
if (mem_offloaded() && !Load(stream_id)) {
MS_LOG(EXCEPTION) << "Load offloaded memory failed";
}
return ptr_;
}
} // namespace device
} // namespace mindspore

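Taken together, the Move* methods above form a small state machine over DeviceAddressStatus; the following summary is a reading aid, not commit code.

// Synchronous MoveTo: device <-> file traffic always stages through host,
// as MoveToDevice/MoveToFile call MoveToHost(false, ...) first when needed.
//
//   kInDevice --MoveToHost--> kInHost --MoveToFile----> kInFile
//   kInFile   --MoveToHost--> kInHost --MoveToDevice--> kInDevice
//
// Asynchronous MoveTo parks in a transit state instead (kInDeviceToHost,
// kInHostToDevice, kInHostToFile, kInFileToHost) until Wait() frees the
// source copy and settles the final status, as implemented in Wait() above.
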
mindspore/ccsrc/runtime/device/loadable_device_address.h View File

@@ -17,6 +17,7 @@
#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_LOADABLE_DEVICE_ADDRESS_H_
#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_LOADABLE_DEVICE_ADDRESS_H_
#include <memory>
#include <string>
#include "runtime/device/device_address.h"
#include "runtime/hardware/device_context.h"
@@ -24,108 +25,98 @@
namespace mindspore {
namespace device {
// LoadableDeviceAddress provide the ability to offload data on device to host and load it back later.
class LoadableDeviceAddress : public DeviceAddress {
struct SwapEvent {
bool NeedWait() {
return aio_token_ != kInvalidAsyncIOToken || (device_event_ != nullptr && device_event_->NeedWait());
}
AsyncIOToken aio_token_{kInvalidAsyncIOToken};
std::shared_ptr<DeviceEvent> device_event_;
};
using SwapEventPtr = std::shared_ptr<SwapEvent>;
// LoadableDeviceAddress provides the ability to offload data on device to DDR or disk and load it back later.
class BACKEND_EXPORT LoadableDeviceAddress : public DeviceAddress {
public:
LoadableDeviceAddress(void *ptr, size_t size) : DeviceAddress(ptr, size) {}
LoadableDeviceAddress(void *ptr, size_t size) : DeviceAddress(ptr, size) {
swap_event_ = std::make_shared<SwapEvent>();
}
LoadableDeviceAddress(void *ptr, size_t size, const string &format, TypeId type_id)
: DeviceAddress(ptr, size, format, type_id) {}
: DeviceAddress(ptr, size, format, type_id) {
swap_event_ = std::make_shared<SwapEvent>();
}
LoadableDeviceAddress(void *ptr, size_t size, const std::string &format, TypeId type_id,
const KernelWithIndex &node_index)
: DeviceAddress(ptr, size, format, type_id, node_index) {}
: DeviceAddress(ptr, size, format, type_id, node_index) {
swap_event_ = std::make_shared<SwapEvent>();
}
LoadableDeviceAddress(void *ptr, size_t size, const std::string &format, TypeId type_id,
const std::string &device_name, uint32_t device_id)
: DeviceAddress(ptr, size, format, type_id, device_name, device_id) {}
: DeviceAddress(ptr, size, format, type_id, device_name, device_id) {
swap_event_ = std::make_shared<SwapEvent>();
}
LoadableDeviceAddress(void *ptr, size_t size, const std::string &device_name, uint32_t device_id)
: DeviceAddress(ptr, size, device_name, device_id) {}
: DeviceAddress(ptr, size, device_name, device_id) {
swap_event_ = std::make_shared<SwapEvent>();
}
LoadableDeviceAddress(void *ptr, size_t size, const std::string &format, TypeId type_id,
const KernelWithIndex &node_index, const std::string &device_name, uint32_t device_id)
: DeviceAddress(ptr, size, format, type_id, node_index, device_name, device_id) {}
: DeviceAddress(ptr, size, format, type_id, node_index, device_name, device_id) {
swap_event_ = std::make_shared<SwapEvent>();
}
bool mem_offloaded() const final { return mem_offloaded_; }
// Offload data from device to host and free device memory
bool Offload(size_t stream_id) final;
// Load data from host to device and free host memory
bool Load(size_t stream_id) final;
// Move data to destination hardware and free resource on source hardware
bool MoveTo(StorageType dest, bool async, size_t stream_id) override;
bool Wait() const override;
void SetStorageInfo(const StorageInfo &storage_info) final;
// Set host ptr data offloaded to
void SetOffloadPtr(void *offload_ptr) final;
// Get offloaded host ptr
void *GetOffloadPtr() const final;
// Return whether DeviceAddress has a valid ptr.
bool IsPtrValid() const final;
// Load first if data is offloaded and return the device ptr.
void *GetValidPtr(size_t stream_id) final;
void HandOver(DeviceAddress *other) override;
protected:
DeviceContext *GetDeviceContext() const {
return DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_name_, device_id_});
}
bool MoveToDevice(bool async, size_t stream_id) const;
bool MoveToHost(bool async, size_t stream_id) const;
bool MoveToFile(bool async, size_t stream_id) const;
virtual bool CopyDeviceToHost(void *dst, const void *src, size_t size, bool async, size_t stream_id) const {
return false;
}
virtual bool CopyHostToDevice(void *dst, const void *src, size_t size, bool async, size_t stream_id) const {
return false;
}
virtual bool CopyHostToFile(const std::string &dst, const void *src, size_t size, bool async) const;
virtual bool CopyFileToHost(void *dst, const std::string &src, size_t size, bool async) const;
std::string GetSwapFileName() const;
size_t GetFileAlignSize() const;
bool mem_offloaded_{false};
void *offload_ptr_{nullptr};
public:
bool mem_offloaded() const final { return mem_offloaded_; }
// Offload data from device to host and free device memory
bool Offload(size_t stream_id) final {
std::lock_guard<std::recursive_mutex> lock(ptr_mutex_);
if (mem_offloaded_) {
MS_LOG(WARNING) << "Trying to offload an offloaded AscendDeviceAddress.";
return true;
}
MS_EXCEPTION_IF_NULL(ptr_);
auto device_context = GetDeviceContext();
MS_EXCEPTION_IF_NULL(device_context);
offload_ptr_ = device_context->device_res_manager_->AllocateOffloadMemory(size_);
if (offload_ptr_ == nullptr) {
MS_LOG(EXCEPTION) << "Alloc host memory for offloading failed, size: " << size_ << ".";
}
if (!AsyncDeviceToHost({}, size_, kTypeUnknown, offload_ptr_, stream_id)) {
return false;
}
device_context->device_res_manager_->FreeMemory(ptr_);
ptr_ = nullptr;
mem_offloaded_ = true;
return true;
}
// Load data from host to device and free host memory
bool Load(size_t stream_id) final {
std::lock_guard<std::recursive_mutex> lock(ptr_mutex_);
if (!mem_offloaded_) {
MS_LOG(DEBUG) << "Trying to load a loaded AscendDeviceAddress.";
return true;
}
MS_EXCEPTION_IF_NULL(offload_ptr_);
auto device_context = GetDeviceContext();
MS_EXCEPTION_IF_NULL(device_context);
if (ptr_ == nullptr && !device_context->device_res_manager_->AllocateMemory(this)) {
MS_LOG(EXCEPTION) << "Alloc memory for loading failed, size: " << size_ << ".";
}
MS_EXCEPTION_IF_NULL(ptr_);
if (!AsyncHostToDevice({}, size_, kTypeUnknown, offload_ptr_, stream_id)) {
return false;
}
device_context->device_res_manager_->FreeOffloadMemory(offload_ptr_);
offload_ptr_ = nullptr;
mem_offloaded_ = false;
return true;
}
// Set host ptr data offloaded to
void SetOffloadPtr(void *offload_ptr) final {
std::lock_guard<std::recursive_mutex> lock(ptr_mutex_);
offload_ptr_ = offload_ptr;
mem_offloaded_ = (offload_ptr != nullptr);
}
// Get offloaded host ptr
void *GetOffloadPtr() const final {
std::lock_guard<std::recursive_mutex> lock(ptr_mutex_);
return offload_ptr_;
}
// Return whether DeviceAddress has a valid ptr.
bool IsPtrValid() const final {
std::lock_guard<std::recursive_mutex> lock(ptr_mutex_);
return ptr_ != nullptr || offload_ptr_ != nullptr;
}
// Load first if data is offloaded and return the device ptr.
void *GetValidPtr(size_t stream_id) final {
std::lock_guard<std::recursive_mutex> lock(ptr_mutex_);
if (mem_offloaded() && !Load(stream_id)) {
MS_LOG(EXCEPTION) << "Load offloaded memory failed.";
}
return ptr_;
}
SwapEventPtr swap_event_{nullptr};
mutable StorageInfo storage_info_;
};
} // namespace device
} // namespace mindspore

memory_swap_actor.cc View File

@@ -59,14 +59,16 @@ void MemorySwapActor::UpdateDeviceTensors(OpContext<mindspore::runtime::DeviceTe
}
}
void MemorySwapActor::GetDeviceTensors(std::vector<size_t> indexes, std::vector<DeviceTensor *> *device_tensors) {
std::vector<DeviceTensor *> MemorySwapActor::GetDeviceTensors(const std::vector<size_t> &indexes) {
std::vector<DeviceTensor *> device_tensors;
for (const auto index : indexes) {
if (index >= device_tensors_to_swap_.size()) {
MS_LOG(EXCEPTION) << "Device tensor index[" << index << "] out of range[" << device_tensors_to_swap_.size()
<< "].";
}
device_tensors->emplace_back(device_tensors_to_swap_[index]);
device_tensors.emplace_back(device_tensors_to_swap_[index]);
}
return device_tensors;
}
void MemorySwapActor::AllocDeviceContinuousMem(const std::vector<DeviceTensor *> &device_tensors) {
@@ -85,34 +87,31 @@ void MemorySwapActor::AllocDeviceContinuousMem(const std::vector<DeviceTensor *>
}
}
void MemorySwapActor::Swap(device::StorageType from, device::StorageType to,
const std::vector<DeviceTensor *> &device_tensors) {
void MemorySwapActor::Swap(device::StorageType to, const std::vector<DeviceTensor *> &device_tensors) {
for (const auto &device_tensor : device_tensors) {
device_tensor->MoveTo(to, false);
MS_EXCEPTION_IF_NULL(device_tensor);
device_tensor->MoveTo(to, false, kDefaultStreamIndex);
}
}
void MemorySwapActor::Run(OpContext<mindspore::runtime::DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
static std::map<device::SwapActionType, std::pair<device::StorageType, device::StorageType>> swap_from_to_map = {
{device::SwapActionType::kHBM2DDR, {device::StorageType::kDevice, device::StorageType::kHost}},
{device::SwapActionType::kDDR2HBM, {device::StorageType::kHost, device::StorageType::kDevice}},
{device::SwapActionType::kDDR2DISK, {device::StorageType::kHost, device::StorageType::kFile}},
{device::SwapActionType::kDISK2DDR, {device::StorageType::kFile, device::StorageType::kHost}},
{device::SwapActionType::kHBM2DISK, {device::StorageType::kDevice, device::StorageType::kFile}},
{device::SwapActionType::kDISK2HBM, {device::StorageType::kFile, device::StorageType::kDevice}}};
static std::map<device::SwapActionType, device::StorageType> swap_to_map = {
{device::SwapActionType::kHBM2DDR, device::StorageType::kHost},
{device::SwapActionType::kDDR2HBM, device::StorageType::kDevice},
{device::SwapActionType::kDDR2DISK, device::StorageType::kFile},
{device::SwapActionType::kDISK2DDR, device::StorageType::kHost},
{device::SwapActionType::kHBM2DISK, device::StorageType::kFile},
{device::SwapActionType::kDISK2HBM, device::StorageType::kDevice}};
UpdateDeviceTensors(context);
for (const auto &action : swap_actions_) {
const auto action_type = action.first;
const auto &device_tensor_indexes = action.second;
std::vector<DeviceTensor *> device_tensors;
GetDeviceTensors(device_tensor_indexes, &device_tensors);
const auto &device_tensors = GetDeviceTensors(device_tensor_indexes);
if (action_type == device::SwapActionType::kAllocHBM) {
AllocDeviceContinuousMem(device_tensors);
} else if (action_type != device::SwapActionType::kUnDefined) {
const auto from = swap_from_to_map[action_type].first;
const auto to = swap_from_to_map[action_type].second;
Swap(from, to, device_tensors);
Swap(swap_to_map[action_type], device_tensors);
} else {
MS_LOG(WARNING) << "Unknown swap action type, skip.";
}

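The old (from, to) pair collapses to a destination-only lookup because the three-argument MoveTo infers the source tier from the tensor's current status. An illustrative fragment (not commit code):

// kDDR2HBM names only where the data must end up; MoveTo discovers that the
// payload currently sits in host memory and frees it after the copy.
const auto to = swap_to_map[device::SwapActionType::kDDR2HBM];  // kDevice
device_tensor->MoveTo(to, /*async=*/false, kDefaultStreamIndex);
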
memory_swap_actor.h View File

@@ -52,9 +52,9 @@ class MemorySwapActor : public AbstractActor {
private:
void AllocDeviceContinuousMem(const std::vector<DeviceTensor *> &device_tensors);
static void Swap(device::StorageType from, device::StorageType to, const std::vector<DeviceTensor *> &device_tensors);
static void Swap(device::StorageType to, const std::vector<DeviceTensor *> &device_tensors);
void UpdateDeviceTensors(OpContext<DeviceTensor> *context);
void GetDeviceTensors(std::vector<size_t> indexes, std::vector<DeviceTensor *> *device_tensors);
std::vector<DeviceTensor *> GetDeviceTensors(const std::vector<size_t> &indexes);
protected:
size_t stream_id_;

mem_swap_scheduler.cc View File

@@ -31,6 +31,7 @@
namespace mindspore {
namespace runtime {
constexpr size_t kFirstVirtualNode = 0;
constexpr size_t kSecondVirtualNodeOffset = 1;
namespace {
AbstractActor *GetCtrlActor(const ControlNodeParserPtr &parser, const KernelGraph *graph, const string &actor_suffix) {
MS_EXCEPTION_IF_NULL(parser);
@@ -55,10 +56,12 @@ std::map<size_t, size_t> GetActionTensors(const std::shared_ptr<device::SwapActi
std::map<size_t, size_t> tensor_indexes;
std::set<size_t> is_real_parameter;
for (const auto &tensor_action : swap_action->actions_) {
MS_EXCEPTION_IF_NULL(tensor_action);
if (tensor_action->tensor_id_ >= swap_strategy->tensor_infos_.size()) {
MS_LOG(EXCEPTION) << "Invalid tensor id " << tensor_action->tensor_id_;
}
const auto &tensor_info = swap_strategy->tensor_infos_[tensor_action->tensor_id_];
MS_EXCEPTION_IF_NULL(tensor_info);
std::vector<size_t> real_tensor_ids;
if (tensor_info->fused_tensor_ids_.empty()) {
real_tensor_ids.emplace_back(tensor_info->tensor_id_);
@@ -73,6 +76,7 @@ std::map<size_t, size_t> GetActionTensors(const std::shared_ptr<device::SwapActi
MS_LOG(EXCEPTION) << "Invalid tensor id " << real_tensor_id;
}
const auto &real_tensor_info = swap_strategy->tensor_infos_[real_tensor_id];
MS_EXCEPTION_IF_NULL(real_tensor_info);
const auto &node = real_tensor_info->node_;
const auto &real_parameter_iter = real_parameters.find(node);
if (real_parameter_iter == real_parameters.end()) {
@@ -105,14 +109,16 @@ void GenActionIndexList(const std::map<size_t, size_t> &tensors_id_index_map,
std::map<device::SwapActionType, vector<size_t>> move_action_map;
const auto &actions = swap_action->actions_;
for (const auto &tensor_action : actions) {
MS_EXCEPTION_IF_NULL(tensor_action);
if (tensor_action->tensor_id_ >= swap_strategy->tensor_infos_.size()) {
MS_LOG(EXCEPTION) << "Invalid tensor id " << tensor_action->tensor_id_;
}
const auto &tensor_info = swap_strategy->tensor_infos_[tensor_action->tensor_id_];
MS_EXCEPTION_IF_NULL(tensor_info);
if (tensor_action->action_ == device::SwapActionType::kAllocHBM) {
std::vector<size_t> indexes;
std::transform(tensor_info->fused_tensor_ids_.begin(), tensor_info->fused_tensor_ids_.end(),
std::back_inserter(indexes), [](size_t index) { return index; });
std::copy(tensor_info->fused_tensor_ids_.begin(), tensor_info->fused_tensor_ids_.end(),
std::back_inserter(indexes));
alloc_action_map.emplace_back(indexes);
} else if (tensor_action->action_ != device::SwapActionType::kUnDefined) {
const auto tensor_id = tensor_info->tensor_id_;
@@ -205,7 +211,7 @@ std::vector<std::vector<MemSwapActorPtr>> MemSwapScheduler::Build(const GraphCom
const auto device_context = graph_compiler_info.device_contexts_[i];
const auto &graph = graph_compiler_info.graphs_[i];
std::vector<MemSwapActorPtr> actors;
if (device_context == nullptr || graph->is_dynamic_shape()) {
if (device_context == nullptr || graph == nullptr || graph->is_dynamic_shape()) {
swap_actors.emplace_back(actors);
continue;
}
@@ -228,7 +234,7 @@ AbstractActor *MemSwapScheduler::GetActorForLink(size_t id, const std::shared_pt
if (ret == nullptr) {
ret = actor_set->data_prepare_actor_.get();
}
} else if (id == graph->execution_order().size()) {
} else if (id == graph->execution_order().size() + kSecondVirtualNodeOffset) {
ret = dynamic_cast<ExitActor *>(GetCtrlActor(parser, graph.get(), kExitActorNameSuffix));
if (ret == nullptr) {
ret = actor_set->loop_count_actor_.get();
@@ -258,7 +264,9 @@ void MemSwapScheduler::Link(const GraphCompilerInfo &graph_compiler_info, ActorS
continue;
}
const auto &strategy = strategy_iter->second;
MS_EXCEPTION_IF_NULL(strategy);
for (const auto &link : strategy->links_) {
MS_EXCEPTION_IF_NULL(link);
const auto from_actor =
GetActorForLink(link->from_, strategy, graph, graph_compiler_info.control_node_parser_, actor_set);
MS_EXCEPTION_IF_NULL(from_actor);

mindspore/ccsrc/runtime/hardware/device_context.h View File

@@ -23,6 +23,7 @@
#include <map>
#include "runtime/hardware/device_type.h"
#include "runtime/device/device_address.h"
#include "runtime/device/gsm/swap_manager.h"
#include "runtime/collective/collective_communication_lib.h"
#include "runtime/collective/collective_comm_lib_loader.h"
#include "backend/common/session/kernel_graph.h"
@@ -160,6 +161,8 @@ class BACKEND_EXPORT DeviceResManager {
// Return collective communication object for caller to access
CollectiveCommunicationLib *collective_comm_lib() const { return collective_comm_lib_; }
std::shared_ptr<SwapManager> swap_manager() const { return swap_manager_; }
protected:
// Ensure the thread safety for allocating device memory.
mutable std::mutex alloc_mem_mutex_;
@@ -169,6 +172,8 @@ class BACKEND_EXPORT DeviceResManager {
DeviceContext *device_context_{nullptr};
std::shared_ptr<SwapManager> swap_manager_{nullptr};
private:
template <class... Args>
friend class DeviceInterface;