From 0b27a636be9f0960e01962e40db463b5d4ea0e7e Mon Sep 17 00:00:00 2001
From: lizhenyu
Date: Sun, 29 May 2022 17:52:09 +0800
Subject: [PATCH] Add embedding cache scheduler

---
Notes (review):
  What this patch does, as visible in the hunks below:
    * embedding_cache_prefetch_actor.cc: adds two file-local helpers,
      MemcpyHostToDeviceAsync / MemcpyDeviceToHostAsync, which wrap a raw device
      pointer in a device address and issue an async copy on the given stream;
      calls them around LookupDeviceCache / UpdateDeviceCache; makes
      Initialize() also build the lookup/update kernels and build/link the RPC
      operators; null-checks keys/values in SendToRemote.
    * embedding_cache_scheduler.cc (new): Initialize() creates the prefetch
      actor only when the PS cache is enabled, the cluster is initialized and
      the process is a worker; Schedule() initializes, spawns and runs the
      actor; Finalize() stops it. Schedule()/Finalize() are no-ops unless
      initialized_ is set.
    * embedding_cache_scheduler.h: declares SyncEmbeddingTable() and the
      initialized_ flag. No definition of SyncEmbeddingTable() is part of this
      patch.
    * tests/ut/cpp/CMakeLists.txt: removes the new .cc from the UT source list.

  NOTE(review): this file was recovered from a copy in which all newlines and
  indentation were collapsed and every <...> span was stripped. The following
  were reconstructed and must be confirmed against the commit before applying:
    * line breaks and indentation of every line (context-line whitespace is
      inferred from the visible code style, e.g. the continuation line
      "host_cache_host_to_device_index, swap_out_data.get())," is aligned on
      the assumption that the call it belongs to is LookupLocalHostCache);
    * template arguments: const_cast<void *>, std::make_unique<float[]>,
      get_param<std::string>, get_param<uint32_t>,
      std::make_shared<EmbeddingCachePrefetchActor>;
    * the two system includes of the new file, presumed <string> and <memory>
      because the file uses std::string and std::make_shared;
    * the author's e-mail address on the From: line is missing;
    * the last hunk (CMakeLists.txt) declares 6 context lines but only 4 are
      present, and the mail trailer is absent - the copy looks truncated.
  Prefer regenerating with: git format-patch -1 0b27a636be9f

 .../embedding_cache_prefetch_actor.cc         | 68 ++++++++++++++++
 .../embedding_cache_scheduler.cc              | 77 +++++++++++++++++++
 .../embedding_cache_scheduler.h               |  9 ++-
 tests/ut/cpp/CMakeLists.txt                   |  1 +
 4 files changed, 154 insertions(+), 1 deletion(-)
 create mode 100644 mindspore/ccsrc/runtime/graph_scheduler/embedding_cache_scheduler.cc

diff --git a/mindspore/ccsrc/runtime/graph_scheduler/actor/embedding_cache/embedding_cache_prefetch_actor.cc b/mindspore/ccsrc/runtime/graph_scheduler/actor/embedding_cache/embedding_cache_prefetch_actor.cc
index a8adedae9d5..3119f04e8a2 100644
--- a/mindspore/ccsrc/runtime/graph_scheduler/actor/embedding_cache/embedding_cache_prefetch_actor.cc
+++ b/mindspore/ccsrc/runtime/graph_scheduler/actor/embedding_cache/embedding_cache_prefetch_actor.cc
@@ -123,6 +123,46 @@ int64_t GetCacheOpsServiceId(const std::string &cache_operation, int32_t param_k
   int64_t id = SizeToLong(distributed::kEmbeddingCacheOps.size()) * IntToLong(param_key) + iter->second;
   return id;
 }
+
+// Async copy host memory to device.
+bool MemcpyHostToDeviceAsync(void *dst, const void *src, size_t size, const DeviceContext *device_context,
+                             size_t stream_id) {
+  MS_ERROR_IF_NULL(dst);
+  MS_ERROR_IF_NULL(src);
+  MS_ERROR_IF_NULL(device_context);
+
+  void *device_ptr = dst;
+  const void *host_ptr = src;
+
+  auto device_address =
+    device_context->CreateDeviceAddress(device_ptr, size, kOpFormat_DEFAULT, kTypeUnknown, ShapeVector());
+  MS_ERROR_IF_NULL(device_address);
+  RETURN_IF_FALSE_WITH_LOG(
+    device_address->AsyncHostToDevice({}, size, kTypeUnknown, host_ptr, device_context->GetStream(stream_id)),
+    "Async memcpy host to device failed.");
+
+  return true;
+}
+
+// Async copy device memory to host.
+bool MemcpyDeviceToHostAsync(void *dst, const void *src, size_t size, const DeviceContext *device_context,
+                             size_t stream_id) {
+  MS_ERROR_IF_NULL(dst);
+  MS_ERROR_IF_NULL(src);
+  MS_ERROR_IF_NULL(device_context);
+
+  void *device_ptr = const_cast<void *>(src);
+  void *host_ptr = dst;
+
+  auto device_address =
+    device_context->CreateDeviceAddress(device_ptr, size, kOpFormat_DEFAULT, kTypeUnknown, ShapeVector());
+  MS_ERROR_IF_NULL(device_address);
+  RETURN_IF_FALSE_WITH_LOG(
+    device_address->AsyncDeviceToHost({}, size, kTypeUnknown, host_ptr, device_context->GetStream(stream_id)),
+    "Async memcpy device to host failed.");
+
+  return true;
+}
 }  // namespace
 
 void EmbeddingCachePrefetchActor::Initialize() {
@@ -130,6 +170,12 @@ void EmbeddingCachePrefetchActor::Initialize() {
   if (!device_context_->CreateStream(&stream_id_)) {
     MS_LOG(EXCEPTION) << "Create stream failed.";
   }
+
+  BuildEmbeddingCacheLookupKernel();
+  BuildEmbeddingCacheUpdateKernel();
+
+  BuildRpcOperators();
+  LinkRpcOperators();
 }
 
 void EmbeddingCachePrefetchActor::Finalize() {
@@ -694,10 +740,21 @@ bool EmbeddingCachePrefetchActor::PushCacheFromDeviceToLocalHost(const HashTable
   auto embedding_size = hash_info.embedding_size;
   auto swap_out_data = std::make_unique<float[]>(swap_indices_size * embedding_size);
 
+  RETURN_IF_FALSE_WITH_LOG(
+    MemcpyHostToDeviceAsync(embedding_device_cache_->hash_swap_index_addr_, device_cache_device_to_host_index,
+                            swap_indices_size * sizeof(int), device_context_, stream_id_),
+    "Memcpy host to device asynchronously failed.");
+
   RETURN_IF_FALSE_WITH_LOG(
     LookupDeviceCache(embedding_device_cache_->hash_swap_index_addr_, hash_table_addr, swap_indices_size,
                       cache_vocab_size, embedding_size, embedding_device_cache_->hash_swap_value_addr_),
     "Lookup device cache failed.");
+
+  RETURN_IF_FALSE_WITH_LOG(
+    MemcpyDeviceToHostAsync(swap_out_data.get(), embedding_device_cache_->hash_swap_value_addr_,
+                            swap_indices_size * embedding_size * sizeof(float), device_context_, stream_id_),
+    "Memcpy device to host asynchronously failed.");
+
   MS_ERROR_IF_NULL(device_context_);
   RETURN_IF_FALSE_WITH_LOG(device_context_->SyncStream(stream_id_), "Synchronize stream failed.");
   RETURN_IF_FALSE_WITH_LOG(
@@ -759,6 +816,15 @@ bool EmbeddingCachePrefetchActor::PullCacheFromLocalHostToDevice(const HashTable
                          host_cache_host_to_device_index, swap_out_data.get()),
     "Lookup local host cache failed.");
 
+  RETURN_IF_FALSE_WITH_LOG(
+    MemcpyHostToDeviceAsync(embedding_device_cache_->hash_swap_value_addr_, swap_out_data.get(),
+                            swap_indices_size * embedding_size * sizeof(float), device_context_, stream_id_),
+    "Memcpy host to device asynchronously failed.");
+  RETURN_IF_FALSE_WITH_LOG(
+    MemcpyHostToDeviceAsync(embedding_device_cache_->hash_swap_index_addr_, device_cache_host_to_device_index,
+                            swap_indices_size * sizeof(int), device_context_, stream_id_),
+    "Memcpy host to device asynchronously failed.");
+
   RETURN_IF_FALSE_WITH_LOG(
     UpdateDeviceCache(embedding_device_cache_->hash_swap_index_addr_, embedding_device_cache_->hash_swap_value_addr_,
                       swap_indices_size, cache_vocab_size, embedding_size, hash_table_addr),
@@ -1058,6 +1124,8 @@ bool EmbeddingCachePrefetchActor::PartitionIdsAndEmbeddings(const int *ids, size
 bool EmbeddingCachePrefetchActor::SendToRemote(const std::string &cache_operation, int32_t param_key,
                                                size_t server_rank_id, size_t embedding_dim, const void *keys,
                                                size_t keys_len, const void *values, size_t values_len) {
+  MS_ERROR_IF_NULL(keys);
+  MS_ERROR_IF_NULL(values);
   // Find sender corresponding to cache operation and parameter key.
   auto iter = rpc_operators_.find(cache_operation);
   if (iter == rpc_operators_.end()) {
diff --git a/mindspore/ccsrc/runtime/graph_scheduler/embedding_cache_scheduler.cc b/mindspore/ccsrc/runtime/graph_scheduler/embedding_cache_scheduler.cc
new file mode 100644
index 00000000000..437535d649b
--- /dev/null
+++ b/mindspore/ccsrc/runtime/graph_scheduler/embedding_cache_scheduler.cc
@@ -0,0 +1,77 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "runtime/graph_scheduler/embedding_cache_scheduler.h"
+#include <string>
+#include <memory>
+
+namespace mindspore {
+namespace runtime {
+void EmbeddingCacheScheduler::Initialize() {
+  if (!ps::PSContext::instance()->cache_enable() || !distributed::cluster::ClusterContext::instance()->initialized() ||
+      !ps::PSContext::instance()->is_worker()) {
+    return;
+  }
+
+  // Get or Create device context.
+  auto ms_context = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(ms_context);
+  std::string device_name = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
+  uint32_t device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
+  const auto &device_context =
+    device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_name, device_id});
+  MS_EXCEPTION_IF_NULL(device_context);
+  device_context->Initialize();
+
+  // Create and initialize EmbeddingCachePrefetchActor.
+  embedding_cache_prefetch_actor_ = std::make_shared<EmbeddingCachePrefetchActor>(device_context);
+  MS_EXCEPTION_IF_NULL(embedding_cache_prefetch_actor_);
+
+  initialized_ = true;
+}
+
+void EmbeddingCacheScheduler::Schedule() const {
+  if (!initialized_) {
+    return;
+  }
+
+  // 1. Initialize embedding cache prefetch actor and build network connection inter process.
+  MS_EXCEPTION_IF_NULL(embedding_cache_prefetch_actor_);
+  embedding_cache_prefetch_actor_->Initialize();
+
+  // 2. Spawn embedding cache prefetch actor.
+  auto actor_manager = ActorMgr::GetActorMgrRef();
+  MS_EXCEPTION_IF_NULL(actor_manager);
+  // Bind single thread to execute embedding cache prefetch actor.
+  (void)actor_manager->Spawn(embedding_cache_prefetch_actor_, false);
+
+  // 3. Run embedding cache prefetch actor.
+  ActorDispatcher::Send(embedding_cache_prefetch_actor_->GetAID(), &EmbeddingCachePrefetchActor::Run);
+}
+
+void EmbeddingCacheScheduler::Finalize() {
+  if (!initialized_) {
+    return;
+  }
+
+  MS_EXCEPTION_IF_NULL(embedding_cache_prefetch_actor_);
+  // Stop the embedding cache prefetch_actor.
+  embedding_cache_prefetch_actor_->Finalize();
+
+  initialized_ = false;
+}
+}  // namespace runtime
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/runtime/graph_scheduler/embedding_cache_scheduler.h b/mindspore/ccsrc/runtime/graph_scheduler/embedding_cache_scheduler.h
index ee0205d84e7..722fc1c5b01 100644
--- a/mindspore/ccsrc/runtime/graph_scheduler/embedding_cache_scheduler.h
+++ b/mindspore/ccsrc/runtime/graph_scheduler/embedding_cache_scheduler.h
@@ -34,17 +34,24 @@ class EmbeddingCacheScheduler {
   // Build and initialize embedding cache prefetch actor and save it by embedding_cache_prefetch_actor_.
   void Initialize();
 
-  // Schedule, Initialize and Run embedding cache prefetch actor.
+  // 1. Build network connection between local and remote cache for embedding cache prefetch actor.
+  // 2. Schedule and Run embedding cache prefetch actor.
   // Since the embedding cache prefetch actor is spinning, and the actor is not in the actor set, start the actor in the
   // Schedule interface.
   void Schedule() const;
 
+  // Synchronize latest embedding table in local cache to remote.
+  void SyncEmbeddingTable() const;
+
   // Finalize embedding cache prefetch actor.
   void Finalize();
 
  private:
   // Embedding cache prefetch actor.
   EmbeddingCachePrefetchActorPtr embedding_cache_prefetch_actor_;
+
+  // The flag indicates whether the EmbeddingCacheScheduler is initialized.
+  bool initialized_{false};
 };
 }  // namespace runtime
 }  // namespace mindspore
diff --git a/tests/ut/cpp/CMakeLists.txt b/tests/ut/cpp/CMakeLists.txt
index f67f8752eb0..6272fd715f9 100644
--- a/tests/ut/cpp/CMakeLists.txt
+++ b/tests/ut/cpp/CMakeLists.txt
@@ -233,6 +233,7 @@ list(REMOVE_ITEM MINDSPORE_SRC_LIST
 list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/plugin/device/cpu/kernel/akg/akg_cpu_kernel_mod.cc")
 list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/plugin/device/cpu/kernel/akg/akg_cpu_kernel_build.cc")
 list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/runtime/graph_scheduler/rpc_node_scheduler.cc")
+list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/runtime/graph_scheduler/embedding_cache_scheduler.cc")
 list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/runtime/graph_scheduler/actor/embedding_cache/embedding_cache_prefetch_actor.cc")