Add embedding cache scheduler

This commit is contained in:
lizhenyu 2022-05-29 17:52:09 +08:00
parent 2ef534e903
commit 0b27a636be
4 changed files with 154 additions and 1 deletions

View File

@ -123,6 +123,46 @@ int64_t GetCacheOpsServiceId(const std::string &cache_operation, int32_t param_k
int64_t id = SizeToLong(distributed::kEmbeddingCacheOps.size()) * IntToLong(param_key) + iter->second;
return id;
}
// Async copy host memory to device.
bool MemcpyHostToDeviceAsync(void *dst, const void *src, size_t size, const DeviceContext *device_context,
size_t stream_id) {
MS_ERROR_IF_NULL(dst);
MS_ERROR_IF_NULL(src);
MS_ERROR_IF_NULL(device_context);
void *device_ptr = dst;
const void *host_ptr = src;
auto device_address =
device_context->CreateDeviceAddress(device_ptr, size, kOpFormat_DEFAULT, kTypeUnknown, ShapeVector());
MS_ERROR_IF_NULL(device_address);
RETURN_IF_FALSE_WITH_LOG(
device_address->AsyncHostToDevice({}, size, kTypeUnknown, host_ptr, device_context->GetStream(stream_id)),
"Async memcpy host to device failed.");
return true;
}
// Async copy device memory to host.
bool MemcpyDeviceToHostAsync(void *dst, const void *src, size_t size, const DeviceContext *device_context,
size_t stream_id) {
MS_ERROR_IF_NULL(dst);
MS_ERROR_IF_NULL(src);
MS_ERROR_IF_NULL(device_context);
void *device_ptr = const_cast<void *>(src);
void *host_ptr = dst;
auto device_address =
device_context->CreateDeviceAddress(device_ptr, size, kOpFormat_DEFAULT, kTypeUnknown, ShapeVector());
MS_ERROR_IF_NULL(device_address);
RETURN_IF_FALSE_WITH_LOG(
device_address->AsyncDeviceToHost({}, size, kTypeUnknown, host_ptr, device_context->GetStream(stream_id)),
"Async memcpy device to host failed.");
return true;
}
} // namespace
void EmbeddingCachePrefetchActor::Initialize() {
@ -130,6 +170,12 @@ void EmbeddingCachePrefetchActor::Initialize() {
if (!device_context_->CreateStream(&stream_id_)) {
MS_LOG(EXCEPTION) << "Create stream failed.";
}
BuildEmbeddingCacheLookupKernel();
BuildEmbeddingCacheUpdateKernel();
BuildRpcOperators();
LinkRpcOperators();
}
void EmbeddingCachePrefetchActor::Finalize() {
@ -694,10 +740,21 @@ bool EmbeddingCachePrefetchActor::PushCacheFromDeviceToLocalHost(const HashTable
auto embedding_size = hash_info.embedding_size;
auto swap_out_data = std::make_unique<float[]>(swap_indices_size * embedding_size);
RETURN_IF_FALSE_WITH_LOG(
MemcpyHostToDeviceAsync(embedding_device_cache_->hash_swap_index_addr_, device_cache_device_to_host_index,
swap_indices_size * sizeof(int), device_context_, stream_id_),
"Memcpy host to device asynchronously failed.");
RETURN_IF_FALSE_WITH_LOG(
LookupDeviceCache(embedding_device_cache_->hash_swap_index_addr_, hash_table_addr, swap_indices_size,
cache_vocab_size, embedding_size, embedding_device_cache_->hash_swap_value_addr_),
"Lookup device cache failed.");
RETURN_IF_FALSE_WITH_LOG(
MemcpyDeviceToHostAsync(swap_out_data.get(), embedding_device_cache_->hash_swap_value_addr_,
swap_indices_size * embedding_size * sizeof(float), device_context_, stream_id_),
"Memcpy device to host asynchronously failed.");
MS_ERROR_IF_NULL(device_context_);
RETURN_IF_FALSE_WITH_LOG(device_context_->SyncStream(stream_id_), "Synchronize stream failed.");
RETURN_IF_FALSE_WITH_LOG(
@ -759,6 +816,15 @@ bool EmbeddingCachePrefetchActor::PullCacheFromLocalHostToDevice(const HashTable
host_cache_host_to_device_index, swap_out_data.get()),
"Lookup local host cache failed.");
RETURN_IF_FALSE_WITH_LOG(
MemcpyHostToDeviceAsync(embedding_device_cache_->hash_swap_value_addr_, swap_out_data.get(),
swap_indices_size * embedding_size * sizeof(float), device_context_, stream_id_),
"Memcpy host to device asynchronously failed.");
RETURN_IF_FALSE_WITH_LOG(
MemcpyHostToDeviceAsync(embedding_device_cache_->hash_swap_index_addr_, device_cache_host_to_device_index,
swap_indices_size * sizeof(int), device_context_, stream_id_),
"Memcpy host to device asynchronously failed.");
RETURN_IF_FALSE_WITH_LOG(
UpdateDeviceCache(embedding_device_cache_->hash_swap_index_addr_, embedding_device_cache_->hash_swap_value_addr_,
swap_indices_size, cache_vocab_size, embedding_size, hash_table_addr),
@ -1058,6 +1124,8 @@ bool EmbeddingCachePrefetchActor::PartitionIdsAndEmbeddings(const int *ids, size
bool EmbeddingCachePrefetchActor::SendToRemote(const std::string &cache_operation, int32_t param_key,
size_t server_rank_id, size_t embedding_dim, const void *keys,
size_t keys_len, const void *values, size_t values_len) {
MS_ERROR_IF_NULL(keys);
MS_ERROR_IF_NULL(values);
// Find sender corresponding to cache operation and parameter key.
auto iter = rpc_operators_.find(cache_operation);
if (iter == rpc_operators_.end()) {

View File

@ -0,0 +1,77 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/graph_scheduler/embedding_cache_scheduler.h"
#include <string>
#include <memory>
namespace mindspore {
namespace runtime {
void EmbeddingCacheScheduler::Initialize() {
if (!ps::PSContext::instance()->cache_enable() || !distributed::cluster::ClusterContext::instance()->initialized() ||
!ps::PSContext::instance()->is_worker()) {
return;
}
// Get or Create device context.
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
std::string device_name = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
uint32_t device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
const auto &device_context =
device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_name, device_id});
MS_EXCEPTION_IF_NULL(device_context);
device_context->Initialize();
// Create and initialize EmbeddingCachePrefetchActor.
embedding_cache_prefetch_actor_ = std::make_shared<EmbeddingCachePrefetchActor>(device_context);
MS_EXCEPTION_IF_NULL(embedding_cache_prefetch_actor_);
initialized_ = true;
}
void EmbeddingCacheScheduler::Schedule() const {
if (!initialized_) {
return;
}
// 1. Initialize embedding cache prefetch actor and build network connection inter process.
MS_EXCEPTION_IF_NULL(embedding_cache_prefetch_actor_);
embedding_cache_prefetch_actor_->Initialize();
// 2. Spawn embedding cache prefetch actor.
auto actor_manager = ActorMgr::GetActorMgrRef();
MS_EXCEPTION_IF_NULL(actor_manager);
// Bind single thread to execute embedding cache prefetch actor.
(void)actor_manager->Spawn(embedding_cache_prefetch_actor_, false);
// 3. Run embedding cache prefetch actor.
ActorDispatcher::Send(embedding_cache_prefetch_actor_->GetAID(), &EmbeddingCachePrefetchActor::Run);
}
void EmbeddingCacheScheduler::Finalize() {
if (!initialized_) {
return;
}
MS_EXCEPTION_IF_NULL(embedding_cache_prefetch_actor_);
// Stop the embedding cache prefetch_actor.
embedding_cache_prefetch_actor_->Finalize();
initialized_ = false;
}
} // namespace runtime
} // namespace mindspore

View File

@ -34,17 +34,24 @@ class EmbeddingCacheScheduler {
// Build and initialize embedding cache prefetch actor and save it by embedding_cache_prefetch_actor_.
void Initialize();
// Schedule, Initialize and Run embedding cache prefetch actor.
// 1. Build network connection between local and remote cache for embedding cache prefetch actor.
// 2. Schedule and Run embedding cache prefetch actor.
// Since the embedding cache prefetch actor is spinning, and the actor is not in the actor set, start the actor in the
// Schedule interface.
void Schedule() const;
// Synchronize latest embedding table in local cache to remote.
void SyncEmbeddingTable() const;
// Finalize embedding cache prefetch actor.
void Finalize();
private:
// Embedding cache prefetch actor.
EmbeddingCachePrefetchActorPtr embedding_cache_prefetch_actor_;
// The flag indicates whether the EmbeddingCacheScheduler is initialized.
bool initialized_{false};
};
} // namespace runtime
} // namespace mindspore

View File

@ -233,6 +233,7 @@ list(REMOVE_ITEM MINDSPORE_SRC_LIST
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/plugin/device/cpu/kernel/akg/akg_cpu_kernel_mod.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/plugin/device/cpu/kernel/akg/akg_cpu_kernel_build.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/runtime/graph_scheduler/rpc_node_scheduler.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/runtime/graph_scheduler/embedding_cache_scheduler.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST
"../../../mindspore/ccsrc/runtime/graph_scheduler/actor/embedding_cache/embedding_cache_prefetch_actor.cc")