Add embedding cache scheduler
This commit is contained in:
parent
2ef534e903
commit
0b27a636be
|
@ -123,6 +123,46 @@ int64_t GetCacheOpsServiceId(const std::string &cache_operation, int32_t param_k
|
|||
int64_t id = SizeToLong(distributed::kEmbeddingCacheOps.size()) * IntToLong(param_key) + iter->second;
|
||||
return id;
|
||||
}
|
||||
|
||||
// Async copy host memory to device.
|
||||
bool MemcpyHostToDeviceAsync(void *dst, const void *src, size_t size, const DeviceContext *device_context,
|
||||
size_t stream_id) {
|
||||
MS_ERROR_IF_NULL(dst);
|
||||
MS_ERROR_IF_NULL(src);
|
||||
MS_ERROR_IF_NULL(device_context);
|
||||
|
||||
void *device_ptr = dst;
|
||||
const void *host_ptr = src;
|
||||
|
||||
auto device_address =
|
||||
device_context->CreateDeviceAddress(device_ptr, size, kOpFormat_DEFAULT, kTypeUnknown, ShapeVector());
|
||||
MS_ERROR_IF_NULL(device_address);
|
||||
RETURN_IF_FALSE_WITH_LOG(
|
||||
device_address->AsyncHostToDevice({}, size, kTypeUnknown, host_ptr, device_context->GetStream(stream_id)),
|
||||
"Async memcpy host to device failed.");
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// Async copy device memory to host.
|
||||
bool MemcpyDeviceToHostAsync(void *dst, const void *src, size_t size, const DeviceContext *device_context,
|
||||
size_t stream_id) {
|
||||
MS_ERROR_IF_NULL(dst);
|
||||
MS_ERROR_IF_NULL(src);
|
||||
MS_ERROR_IF_NULL(device_context);
|
||||
|
||||
void *device_ptr = const_cast<void *>(src);
|
||||
void *host_ptr = dst;
|
||||
|
||||
auto device_address =
|
||||
device_context->CreateDeviceAddress(device_ptr, size, kOpFormat_DEFAULT, kTypeUnknown, ShapeVector());
|
||||
MS_ERROR_IF_NULL(device_address);
|
||||
RETURN_IF_FALSE_WITH_LOG(
|
||||
device_address->AsyncDeviceToHost({}, size, kTypeUnknown, host_ptr, device_context->GetStream(stream_id)),
|
||||
"Async memcpy device to host failed.");
|
||||
|
||||
return true;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void EmbeddingCachePrefetchActor::Initialize() {
|
||||
|
@ -130,6 +170,12 @@ void EmbeddingCachePrefetchActor::Initialize() {
|
|||
if (!device_context_->CreateStream(&stream_id_)) {
|
||||
MS_LOG(EXCEPTION) << "Create stream failed.";
|
||||
}
|
||||
|
||||
BuildEmbeddingCacheLookupKernel();
|
||||
BuildEmbeddingCacheUpdateKernel();
|
||||
|
||||
BuildRpcOperators();
|
||||
LinkRpcOperators();
|
||||
}
|
||||
|
||||
void EmbeddingCachePrefetchActor::Finalize() {
|
||||
|
@ -694,10 +740,21 @@ bool EmbeddingCachePrefetchActor::PushCacheFromDeviceToLocalHost(const HashTable
|
|||
auto embedding_size = hash_info.embedding_size;
|
||||
auto swap_out_data = std::make_unique<float[]>(swap_indices_size * embedding_size);
|
||||
|
||||
RETURN_IF_FALSE_WITH_LOG(
|
||||
MemcpyHostToDeviceAsync(embedding_device_cache_->hash_swap_index_addr_, device_cache_device_to_host_index,
|
||||
swap_indices_size * sizeof(int), device_context_, stream_id_),
|
||||
"Memcpy host to device asynchronously failed.");
|
||||
|
||||
RETURN_IF_FALSE_WITH_LOG(
|
||||
LookupDeviceCache(embedding_device_cache_->hash_swap_index_addr_, hash_table_addr, swap_indices_size,
|
||||
cache_vocab_size, embedding_size, embedding_device_cache_->hash_swap_value_addr_),
|
||||
"Lookup device cache failed.");
|
||||
|
||||
RETURN_IF_FALSE_WITH_LOG(
|
||||
MemcpyDeviceToHostAsync(swap_out_data.get(), embedding_device_cache_->hash_swap_value_addr_,
|
||||
swap_indices_size * embedding_size * sizeof(float), device_context_, stream_id_),
|
||||
"Memcpy device to host asynchronously failed.");
|
||||
|
||||
MS_ERROR_IF_NULL(device_context_);
|
||||
RETURN_IF_FALSE_WITH_LOG(device_context_->SyncStream(stream_id_), "Synchronize stream failed.");
|
||||
RETURN_IF_FALSE_WITH_LOG(
|
||||
|
@ -759,6 +816,15 @@ bool EmbeddingCachePrefetchActor::PullCacheFromLocalHostToDevice(const HashTable
|
|||
host_cache_host_to_device_index, swap_out_data.get()),
|
||||
"Lookup local host cache failed.");
|
||||
|
||||
RETURN_IF_FALSE_WITH_LOG(
|
||||
MemcpyHostToDeviceAsync(embedding_device_cache_->hash_swap_value_addr_, swap_out_data.get(),
|
||||
swap_indices_size * embedding_size * sizeof(float), device_context_, stream_id_),
|
||||
"Memcpy host to device asynchronously failed.");
|
||||
RETURN_IF_FALSE_WITH_LOG(
|
||||
MemcpyHostToDeviceAsync(embedding_device_cache_->hash_swap_index_addr_, device_cache_host_to_device_index,
|
||||
swap_indices_size * sizeof(int), device_context_, stream_id_),
|
||||
"Memcpy host to device asynchronously failed.");
|
||||
|
||||
RETURN_IF_FALSE_WITH_LOG(
|
||||
UpdateDeviceCache(embedding_device_cache_->hash_swap_index_addr_, embedding_device_cache_->hash_swap_value_addr_,
|
||||
swap_indices_size, cache_vocab_size, embedding_size, hash_table_addr),
|
||||
|
@ -1058,6 +1124,8 @@ bool EmbeddingCachePrefetchActor::PartitionIdsAndEmbeddings(const int *ids, size
|
|||
bool EmbeddingCachePrefetchActor::SendToRemote(const std::string &cache_operation, int32_t param_key,
|
||||
size_t server_rank_id, size_t embedding_dim, const void *keys,
|
||||
size_t keys_len, const void *values, size_t values_len) {
|
||||
MS_ERROR_IF_NULL(keys);
|
||||
MS_ERROR_IF_NULL(values);
|
||||
// Find sender corresponding to cache operation and parameter key.
|
||||
auto iter = rpc_operators_.find(cache_operation);
|
||||
if (iter == rpc_operators_.end()) {
|
||||
|
|
|
@ -0,0 +1,77 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "runtime/graph_scheduler/embedding_cache_scheduler.h"
|
||||
#include <string>
|
||||
#include <memory>
|
||||
|
||||
namespace mindspore {
|
||||
namespace runtime {
|
||||
void EmbeddingCacheScheduler::Initialize() {
|
||||
if (!ps::PSContext::instance()->cache_enable() || !distributed::cluster::ClusterContext::instance()->initialized() ||
|
||||
!ps::PSContext::instance()->is_worker()) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Get or Create device context.
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
std::string device_name = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
|
||||
uint32_t device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
|
||||
const auto &device_context =
|
||||
device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_name, device_id});
|
||||
MS_EXCEPTION_IF_NULL(device_context);
|
||||
device_context->Initialize();
|
||||
|
||||
// Create and initialize EmbeddingCachePrefetchActor.
|
||||
embedding_cache_prefetch_actor_ = std::make_shared<EmbeddingCachePrefetchActor>(device_context);
|
||||
MS_EXCEPTION_IF_NULL(embedding_cache_prefetch_actor_);
|
||||
|
||||
initialized_ = true;
|
||||
}
|
||||
|
||||
void EmbeddingCacheScheduler::Schedule() const {
|
||||
if (!initialized_) {
|
||||
return;
|
||||
}
|
||||
|
||||
// 1. Initialize embedding cache prefetch actor and build network connection inter process.
|
||||
MS_EXCEPTION_IF_NULL(embedding_cache_prefetch_actor_);
|
||||
embedding_cache_prefetch_actor_->Initialize();
|
||||
|
||||
// 2. Spawn embedding cache prefetch actor.
|
||||
auto actor_manager = ActorMgr::GetActorMgrRef();
|
||||
MS_EXCEPTION_IF_NULL(actor_manager);
|
||||
// Bind single thread to execute embedding cache prefetch actor.
|
||||
(void)actor_manager->Spawn(embedding_cache_prefetch_actor_, false);
|
||||
|
||||
// 3. Run embedding cache prefetch actor.
|
||||
ActorDispatcher::Send(embedding_cache_prefetch_actor_->GetAID(), &EmbeddingCachePrefetchActor::Run);
|
||||
}
|
||||
|
||||
void EmbeddingCacheScheduler::Finalize() {
|
||||
if (!initialized_) {
|
||||
return;
|
||||
}
|
||||
|
||||
MS_EXCEPTION_IF_NULL(embedding_cache_prefetch_actor_);
|
||||
// Stop the embedding cache prefetch_actor.
|
||||
embedding_cache_prefetch_actor_->Finalize();
|
||||
|
||||
initialized_ = false;
|
||||
}
|
||||
} // namespace runtime
|
||||
} // namespace mindspore
|
|
@ -34,17 +34,24 @@ class EmbeddingCacheScheduler {
|
|||
// Build and initialize embedding cache prefetch actor and save it by embedding_cache_prefetch_actor_.
|
||||
void Initialize();
|
||||
|
||||
// Schedule, Initialize and Run embedding cache prefetch actor.
|
||||
// 1. Build network connection between local and remote cache for embedding cache prefetch actor.
|
||||
// 2. Schedule and Run embedding cache prefetch actor.
|
||||
// Since the embedding cache prefetch actor is spinning, and the actor is not in the actor set, start the actor in the
|
||||
// Schedule interface.
|
||||
void Schedule() const;
|
||||
|
||||
// Synchronize latest embedding table in local cache to remote.
|
||||
void SyncEmbeddingTable() const;
|
||||
|
||||
// Finalize embedding cache prefetch actor.
|
||||
void Finalize();
|
||||
|
||||
private:
|
||||
// Embedding cache prefetch actor.
|
||||
EmbeddingCachePrefetchActorPtr embedding_cache_prefetch_actor_;
|
||||
|
||||
// The flag indicates whether the EmbeddingCacheScheduler is initialized.
|
||||
bool initialized_{false};
|
||||
};
|
||||
} // namespace runtime
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -233,6 +233,7 @@ list(REMOVE_ITEM MINDSPORE_SRC_LIST
|
|||
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/plugin/device/cpu/kernel/akg/akg_cpu_kernel_mod.cc")
|
||||
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/plugin/device/cpu/kernel/akg/akg_cpu_kernel_build.cc")
|
||||
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/runtime/graph_scheduler/rpc_node_scheduler.cc")
|
||||
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/runtime/graph_scheduler/embedding_cache_scheduler.cc")
|
||||
list(REMOVE_ITEM MINDSPORE_SRC_LIST
|
||||
"../../../mindspore/ccsrc/runtime/graph_scheduler/actor/embedding_cache/embedding_cache_prefetch_actor.cc")
|
||||
|
||||
|
|
Loading…
Reference in New Issue