diff --git a/.jenkins/check/config/filter_cppcheck.txt b/.jenkins/check/config/filter_cppcheck.txt
index 4610f5cc56d..67bddf0feb7 100644
--- a/.jenkins/check/config/filter_cppcheck.txt
+++ b/.jenkins/check/config/filter_cppcheck.txt
@@ -28,6 +28,7 @@
 "mindspore/mindspore/core/load_mindir/anf_model_parser.cc" "stlIfStrFind"
 "mindspore/mindspore/ccsrc/frontend/optimizer/ad/kpynative.cc" "containerOutOfBounds"
 "mindspore/mindspore/ccsrc/pipeline/jit/action.cc" "unreadVariable"
+"mindspore/mindspore/ccsrc/plugin/device/gpu/hal/hardware/gpu_device_context.cc" "unreadVariable"

 # MindData
 "mindspore/mindspore/ccsrc/minddata/dataset/engine/dataset_iterator.cc" "useStlAlgorithm"
diff --git a/mindspore/ccsrc/distributed/rpc/tcp/connection.cc b/mindspore/ccsrc/distributed/rpc/tcp/connection.cc
index 3a882dd7cbe..8e6eef2eff7 100644
--- a/mindspore/ccsrc/distributed/rpc/tcp/connection.cc
+++ b/mindspore/ccsrc/distributed/rpc/tcp/connection.cc
@@ -513,6 +513,10 @@ bool Connection::ParseMessage() {
         total_recv_len -= recvLen;
         return false;
       }
+      if (!SetUrlForRecvMessage()) {
+        MS_LOG(ERROR) << "Set url info for recv message failed.";
+        return false;
+      }
       recv_state = State::kMsgHeader;
       break;
     default:
@@ -521,6 +525,24 @@ bool Connection::ParseMessage() {
   return true;
 }

+bool Connection::SetUrlForRecvMessage() {
+  auto recv_from_separator_pos = recv_from.find('@');
+  auto recv_to_separator_pos = recv_to.find('@');
+  // Both addresses must contain the '@' separator, otherwise the substr calls below would misparse.
+  if (recv_from_separator_pos == std::string::npos || recv_to_separator_pos == std::string::npos) {
+    MS_LOG(ERROR) << "Invalid message format, cannot find separator '@'";
+    return false;
+  }
+
+  std::string from_name = recv_from.substr(0, recv_from_separator_pos);
+  std::string from_url = recv_from.substr(recv_from_separator_pos + 1);
+  std::string to_name = recv_to.substr(0, recv_to_separator_pos);
+  std::string to_url = recv_to.substr(recv_to_separator_pos + 1);
+  recv_message->from = AID(from_name, from_url);
+  recv_message->to = AID(to_name, to_url);
+
+  return true;
+}
+
 void Connection::ReorderHeader(MessageHeader *header) const {
   header->name_len = ntohl(header->name_len);
   header->to_len = ntohl(header->to_len);
diff --git a/mindspore/ccsrc/distributed/rpc/tcp/connection.h b/mindspore/ccsrc/distributed/rpc/tcp/connection.h
index 8f5265cc801..e85a9bbe297 100644
--- a/mindspore/ccsrc/distributed/rpc/tcp/connection.h
+++ b/mindspore/ccsrc/distributed/rpc/tcp/connection.h
@@ -204,6 +204,9 @@ struct Connection {
   // Parse message from socket recv buffer.
   bool ParseMessage();

+  // After ParseMessage, set the from/to urls on the received message.
+  bool SetUrlForRecvMessage();
+
   // Make an HTTP message based on the given input message.
   std::string GenerateHttpMessage(MessageBase *msg);
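For context, this patch assumes actor addresses are encoded as `name@url` strings that get split into an `AID`. A minimal, standalone sketch of the split logic (the `SplitAid` helper and the example address are hypothetical, not MindSpore APIs):

```cpp
#include <iostream>
#include <string>
#include <utility>

// Split an address of the form "name@url" into its two parts.
// Returns {"", ""} when the '@' separator is missing, mirroring the
// error path of SetUrlForRecvMessage above.
std::pair<std::string, std::string> SplitAid(const std::string &aid) {
  auto pos = aid.find('@');
  if (pos == std::string::npos) {
    return {"", ""};
  }
  return {aid.substr(0, pos), aid.substr(pos + 1)};
}

int main() {
  auto [name, url] = SplitAid("worker_0@127.0.0.1:8080");
  std::cout << name << " -> " << url << std::endl;  // worker_0 -> 127.0.0.1:8080
  return 0;
}
```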
diff --git a/mindspore/ccsrc/frontend/parallel/cache_embedding/ps_embedding_cache_inserter.cc b/mindspore/ccsrc/frontend/parallel/cache_embedding/ps_embedding_cache_inserter.cc
index 3c2b6767bb5..41880217ae2 100644
--- a/mindspore/ccsrc/frontend/parallel/cache_embedding/ps_embedding_cache_inserter.cc
+++ b/mindspore/ccsrc/frontend/parallel/cache_embedding/ps_embedding_cache_inserter.cc
@@ -65,6 +65,26 @@ void PsEmbeddingCacheInserter::GetEmbeddingLookupNodes() {
   });
 }

+void PsEmbeddingCacheInserter::GetCacheEnableParameters() {
+  MS_EXCEPTION_IF_NULL(root_graph_);
+  const std::vector<AnfNodePtr> &parameters = root_graph_->parameters();
+  auto params_size = parameters.size();
+  for (size_t i = 0; i < params_size; ++i) {
+    MS_EXCEPTION_IF_NULL(parameters[i]);
+    if (!parameters[i]->isa<Parameter>()) {
+      MS_LOG(EXCEPTION) << "The node with name: " << parameters[i]->fullname_with_scope() << " is not a Parameter.";
+    }
+
+    ParameterPtr param = parameters[i]->cast<ParameterPtr>();
+    MS_EXCEPTION_IF_NULL(param);
+    auto param_info = param->param_info();
+    if (param_info && param_info->key() != -1 && param_info->cache_enable()) {
+      keys_to_params_[param_info->key()] = param;
+      MS_LOG(INFO) << "Parameter[" << param->fullname_with_scope() << "], key[" << param_info->key() << "]";
+    }
+  }
+}
+
 void PsEmbeddingCacheInserter::SetNodeAttr(const CNodePtr &node, const std::string &node_role) const {
   MS_EXCEPTION_IF_NULL(node);

@@ -107,7 +127,7 @@ void PsEmbeddingCacheInserter::SetSendNodeAttr(const CNodePtr &send_node, int32_
       dst_ranks.push_back(i);
       dst_roles.push_back(dst_role);
       // Unique edge name: src role + src rank id -> dst role + dst rank id + embedding cache operation + parameter key.
-      inter_process_edges.push_back(distributed::kEnvRoleOfServer + std::to_string(rank_id_) + "->" + dst_role +
+      inter_process_edges.push_back(distributed::kEnvRoleOfPServer + std::to_string(rank_id_) + "->" + dst_role +
                                     std::to_string(i) + "_" + embedding_cache_op + "_" + distributed::kParameterKey +
                                     std::to_string(param_key));
     }
@@ -118,6 +138,7 @@ void PsEmbeddingCacheInserter::SetSendNodeAttr(const CNodePtr &send_node, int32_
   common::AnfAlgo::SetNodeAttr(kAttrSendDstNodeName, MakeValue(std::string(kEmbeddingLocalCacheNode)), send_node);

   common::AnfAlgo::SetNodeAttr(kAttrInterProcessEdgeNames, MakeValue(inter_process_edges), send_node);
+  common::AnfAlgo::SetNodeAttr(kAttrIsMuxRpcKernel, MakeValue(true), send_node);
 }

 void PsEmbeddingCacheInserter::SetRecvNodeAttr(const CNodePtr &recv_node, const std::string &src_role) const {
@@ -139,7 +160,7 @@ void PsEmbeddingCacheInserter::SetRecvNodeAttr(const CNodePtr &recv_node, const
       src_roles.push_back(src_role);
       // Unique edge name: src role + src rank id -> dst role + dst rank id + embedding cache operation + parameter
       // key.
-      inter_process_edges.push_back(src_role + std::to_string(i) + "->" + distributed::kEnvRoleOfServer +
+      inter_process_edges.push_back(src_role + std::to_string(i) + "->" + distributed::kEnvRoleOfPServer +
                                     std::to_string(rank_id_) + "_" + distributed::kEmbeddingCacheOps[k] + "_" +
                                     distributed::kParameterKey + std::to_string(param_key));
     }
@@ -150,7 +171,9 @@ void PsEmbeddingCacheInserter::SetRecvNodeAttr(const CNodePtr &recv_node, const
   common::AnfAlgo::SetNodeAttr(kAttrRecvSrcRoles, MakeValue(src_roles), recv_node);
   common::AnfAlgo::SetNodeAttr(kAttrRecvSrcNodeName, MakeValue(std::string(kEmbeddingLocalCacheNode)), recv_node);
   common::AnfAlgo::SetNodeAttr(kAttrRecvDstNodeName, MakeValue(std::string(kEmbeddingRemoteCacheNode)), recv_node);
+
   common::AnfAlgo::SetNodeAttr(kAttrInterProcessEdgeNames, MakeValue(inter_process_edges), recv_node);
+  common::AnfAlgo::SetNodeAttr(kAttrIsMuxRpcKernel, MakeValue(true), recv_node);
 }

 CNodePtr PsEmbeddingCacheInserter::CreateReturnNode(const FuncGraphPtr graph, const AnfNodePtr &output_node) const {
@@ -421,6 +444,9 @@ bool PsEmbeddingCacheInserter::Run() {
   // Get EmbeddingLookup nodes which are executed on server from origin function graph.
   GetEmbeddingLookupNodes();

+  // Get the parameters with embedding cache enabled from the origin function graph.
+  GetCacheEnableParameters();
+
   // Construct the embedding cache graph of server.
   RETURN_IF_FALSE_WITH_LOG(ConstructEmbeddingCacheGraph(), "Construct embedding cache graph failed.");
diff --git a/mindspore/ccsrc/frontend/parallel/cache_embedding/ps_embedding_cache_inserter.h b/mindspore/ccsrc/frontend/parallel/cache_embedding/ps_embedding_cache_inserter.h
index 4170a2a9855..e60a16e3229 100644
--- a/mindspore/ccsrc/frontend/parallel/cache_embedding/ps_embedding_cache_inserter.h
+++ b/mindspore/ccsrc/frontend/parallel/cache_embedding/ps_embedding_cache_inserter.h
@@ -21,7 +21,6 @@
 #include
 #include
-#include "utils/hash_map.h"
 #include "ir/anf.h"
 #include "distributed/constants.h"
@@ -104,7 +103,7 @@ class PsEmbeddingCacheInserter {
   // Record parameters enabled embedding cache of origin function graph.
   // Key: parameter key, Value: ParameterPtr
-  mindspore::HashMap<int32_t, ParameterPtr> keys_to_params_;
+  std::map<int32_t, ParameterPtr> keys_to_params_;

   // Record EmbeddingLookup nodes which are executed on server from origin function graph.
   // Key: shape of EmbeddingLookup node, Value: EmbeddingLookup AnfNodePtr.
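A plausible motivation for switching `keys_to_params_` from `mindspore::HashMap` to `std::map` is that an ordered map makes iteration over parameter keys deterministic, so every process walks the cache-enabled parameters in the same order. A standalone sketch of the difference, with a simplified stand-in for `ParameterPtr`:

```cpp
#include <iostream>
#include <map>
#include <memory>
#include <string>

// Simplified stand-in for ParameterPtr: just a named parameter.
using ParameterPtr = std::shared_ptr<std::string>;

int main() {
  // std::map keeps entries sorted by parameter key, so iteration order is
  // deterministic regardless of insertion order (a hash map gives no such
  // guarantee).
  std::map<int32_t, ParameterPtr> keys_to_params;
  keys_to_params[2] = std::make_shared<std::string>("embedding_table_b");
  keys_to_params[0] = std::make_shared<std::string>("embedding_table_a");

  for (const auto &[key, param] : keys_to_params) {  // always key 0, then key 2
    std::cout << "key " << key << " -> " << *param << std::endl;
  }
  return 0;
}
```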
diff --git a/mindspore/ccsrc/include/common/utils/utils.h b/mindspore/ccsrc/include/common/utils/utils.h
index 2187ff48503..56a53dfc089 100644
--- a/mindspore/ccsrc/include/common/utils/utils.h
+++ b/mindspore/ccsrc/include/common/utils/utils.h
@@ -580,6 +580,7 @@
 constexpr auto kAttrRecvSrcRanks = "recv_src_ranks";
 constexpr auto kAttrRecvSrcRoles = "recv_src_roles";
 constexpr auto kAttrInterProcessEdgeNames = "inter_process_edge_names";
 constexpr auto kAttrInterProcessEdgeLabel = "inter_process_edge_label";
+constexpr auto kAttrIsMuxRpcKernel = "is_mux_rpc_kernel";
 constexpr auto kAttrForwardOpOutputId = "forward_op_output_id";
 constexpr auto kAttrGroupRankIds = "group_rank_ids";
 constexpr auto kAttrReuseCommunication = "reuse_communication_node";
diff --git a/mindspore/ccsrc/pipeline/jit/pipeline.cc b/mindspore/ccsrc/pipeline/jit/pipeline.cc
index 3a9747ebead..d0c0e89fd46 100644
--- a/mindspore/ccsrc/pipeline/jit/pipeline.cc
+++ b/mindspore/ccsrc/pipeline/jit/pipeline.cc
@@ -1423,6 +1423,12 @@ bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batc
   MS_EXCEPTION_IF_NULL(backend);
   // The data set graph compiling and running of mindRT.
   if (MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_MINDRT)) {
+#ifdef WITH_BACKEND
+    if (ps::PSContext::instance()->is_worker() && ps::PSContext::instance()->cache_enable()) {
+      ps::PsDataPrefetch::GetInstance().CreateDataChannel(queue_name, LongToSize(size));
+    }
+#endif
+
     const auto &mindrt_backend = std::dynamic_pointer_cast<compile::MindRTBackend>(backend);
     MS_EXCEPTION_IF_NULL(mindrt_backend);
     auto &actor_info = mindrt_backend->CompileGraphs(func_graph);
diff --git a/mindspore/ccsrc/plugin/device/gpu/hal/hardware/gpu_device_context.cc b/mindspore/ccsrc/plugin/device/gpu/hal/hardware/gpu_device_context.cc
index 566991813c1..3838080677a 100644
--- a/mindspore/ccsrc/plugin/device/gpu/hal/hardware/gpu_device_context.cc
+++ b/mindspore/ccsrc/plugin/device/gpu/hal/hardware/gpu_device_context.cc
@@ -361,6 +361,61 @@ void SetKernelInfoBeforeCreateKernel(const std::vector<AnfNodePtr> &nodes) {
     }
   }
 }
+
+// Check whether a mutex already exists for the stream.
+std::pair<bool, std::mutex *> CheckStreamMutexExist(
+  const void *stream, const mindspore::HashMap<const void *, std::shared_ptr<std::mutex>> &mtxs_for_streams,
+  std::shared_mutex *shd_mtx) {
+  MS_EXCEPTION_IF_NULL(stream);
+  MS_EXCEPTION_IF_NULL(shd_mtx);
+  std::shared_lock<std::shared_mutex> shd_lock(*shd_mtx);
+  auto iter = mtxs_for_streams.find(stream);
+  if (iter != mtxs_for_streams.end()) {
+    MS_EXCEPTION_IF_NULL(iter->second);
+    return std::make_pair(true, iter->second.get());
+  }
+  return std::make_pair(false, nullptr);
+}
+
+// Create a mutex for the stream.
+std::mutex *CreateStreamMutex(const void *stream, std::shared_mutex *shd_mtx,
+                              mindspore::HashMap<const void *, std::shared_ptr<std::mutex>> *mtxs_for_streams) {
+  MS_EXCEPTION_IF_NULL(stream);
+  MS_EXCEPTION_IF_NULL(shd_mtx);
+  MS_EXCEPTION_IF_NULL(mtxs_for_streams);
+
+  std::unique_lock<std::shared_mutex> unq_lock(*shd_mtx);
+  auto ret_pair = mtxs_for_streams->emplace(stream, std::make_shared<std::mutex>());
+
+  MS_EXCEPTION_IF_NULL(ret_pair.first->second);
+  return ret_pair.first->second.get();
+}
+
+// Kernel launch is not thread-safe: launches delivered to the same stream must be serialized with a lock, so a
+// separate mutex is created for each stream. On GPU, the cuBLAS handle in particular is not thread-safe; multiple
+// threads should not access the same cuBLAS handle concurrently, hence the launch mutex is required when multiple
+// threads launch cuBLAS kernels.
+std::lock_guard<std::mutex> LockLaunchKernel(const void *stream) {
+  MS_EXCEPTION_IF_NULL(stream);
+  // Read-write lock for accessing the mtxs_for_streams map.
+  // Once the lock of a stream has been created, mtxs_for_streams can be read concurrently to improve performance.
+  static std::shared_mutex shd_mtx;
+  static mindspore::HashMap<const void *, std::shared_ptr<std::mutex>> mtxs_for_streams;
+
+  std::mutex *stream_mtx = nullptr;
+  // Check whether a mutex already exists for the stream.
+  std::pair<bool, std::mutex *> ret_pair = CheckStreamMutexExist(stream, mtxs_for_streams, &shd_mtx);
+  if (ret_pair.first) {
+    stream_mtx = ret_pair.second;
+  } else {
+    // Create a mutex for the stream.
+    stream_mtx = CreateStreamMutex(stream, &shd_mtx, &mtxs_for_streams);
+  }
+
+  MS_EXCEPTION_IF_NULL(stream_mtx);
+  // Lock kernel launch for the stream.
+  return std::lock_guard<std::mutex>(*stream_mtx);
+}
 }  // namespace

 void GPUDeviceContext::OptimizeGraph(const FuncGraphPtr &graph) const {
@@ -462,20 +517,23 @@ bool GPUDeviceContext::LaunchKernel(const CNodePtr &kernel, const std::vector<AddressPtr>
     if (!profiler_inst->GetEnableFlag()) {
 #endif
-      std::lock_guard<std::mutex> locker(launch_mutex_);
+      auto lock = LockLaunchKernel(stream);
       MS_LOG(DEBUG) << "Launch kernel: " << kernel->fullname_with_scope();
-      ret = DoLaunchKernel(kernel, inputs, workspace, outputs);
+      ret = DoLaunchKernel(kernel, inputs, workspace, outputs, stream);
 #ifndef ENABLE_SECURITY
     } else {
-      std::lock_guard<std::mutex> locker(launch_mutex_);
+      auto lock = LockLaunchKernel(stream);
       MS_LOG(DEBUG) << "Launch kernel: " << kernel->fullname_with_scope();
-      ret = LaunchKernelWithProfiling(kernel, inputs, workspace, outputs);
+      ret = LaunchKernelWithProfiling(kernel, inputs, workspace, outputs, stream);
     }
 #endif
   if (!ret) {
@@ -496,8 +554,9 @@ bool GPUDeviceContext::LaunchKernel(const CNodePtr &kernel, const std::vector<AddressPtr>
 bool GPUDeviceContext::LaunchKernelWithProfiling(const CNodePtr &kernel, const std::vector<AddressPtr> &inputs,
                                                  const std::vector<AddressPtr> &workspace,
-                                                 const std::vector<AddressPtr> &outputs) const {
+                                                 const std::vector<AddressPtr> &outputs, void *stream) const {
   MS_EXCEPTION_IF_NULL(kernel);
+  MS_EXCEPTION_IF_NULL(stream);
   auto kernel_graph = std::dynamic_pointer_cast<session::KernelGraph>(kernel->func_graph());
   MS_EXCEPTION_IF_NULL(kernel_graph);
@@ -512,7 +571,7 @@ bool GPUDeviceContext::LaunchKernelWithProfiling(const CNodePtr &kernel, const s
   }

   profiler_inst->OpDataProducerBegin(kernel->fullname_with_scope(), streams_.front());
-  bool ret = DoLaunchKernel(kernel, inputs, workspace, outputs);
+  bool ret = DoLaunchKernel(kernel, inputs, workspace, outputs, stream);
   profiler_inst->OpDataProducerEnd();
   profiler_inst->RecordFrameWorkInfo(kernel);
@@ -527,8 +586,16 @@ bool GPUDeviceContext::LaunchKernelWithProfiling(const CNodePtr &kernel, const s
 }
 #endif
 bool GPUDeviceContext::DoLaunchKernel(const CNodePtr &kernel, const std::vector<AddressPtr> &inputs,
-                                      const std::vector<AddressPtr> &workspace,
-                                      const std::vector<AddressPtr> &outputs) const {
+                                      const std::vector<AddressPtr> &workspace, const std::vector<AddressPtr> &outputs,
+                                      void *stream) const {
+  MS_EXCEPTION_IF_NULL(kernel);
+  MS_EXCEPTION_IF_NULL(stream);
+  auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
+  MS_EXCEPTION_IF_NULL(kernel_mod);
+  return kernel_mod->Launch(inputs, workspace, outputs, stream);
+}
+
+void *GPUDeviceContext::GetLaunchKernelStream(const CNodePtr &kernel) const {
   void *stream = nullptr;
   if (common::AnfAlgo::HasNodeAttr(kAttrStream, kernel)) {
     auto stream_id = common::AnfAlgo::GetNodeAttr<size_t>(kernel, kAttrStream);
@@ -542,9 +609,7 @@ bool GPUDeviceContext::DoLaunchKernel(const CNodePtr &kernel, const std::vector<
   }

   MS_EXCEPTION_IF_NULL(stream);
-  auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
-  MS_EXCEPTION_IF_NULL(kernel_mod);
-  return kernel_mod->Launch(inputs, workspace, outputs, stream);
+  return stream;
 }

 bool GPUDeviceContext::SyncStream(size_t stream_id) const {
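The per-stream locking above follows a read-mostly pattern: a `std::shared_mutex` guards the stream-to-mutex map so lookups can run concurrently, while creating a missing entry takes the exclusive lock. A standalone sketch of the same pattern using only the standard library (names are illustrative, not the MindSpore implementation):

```cpp
#include <iostream>
#include <memory>
#include <mutex>
#include <shared_mutex>
#include <unordered_map>

// Map each stream to its own launch mutex. Lookups (the common case) take a
// shared lock; creating a missing entry takes the exclusive lock.
std::mutex &MutexForStream(const void *stream) {
  static std::shared_mutex map_mtx;
  static std::unordered_map<const void *, std::unique_ptr<std::mutex>> mtxs;

  {
    std::shared_lock<std::shared_mutex> read_lock(map_mtx);
    auto iter = mtxs.find(stream);
    if (iter != mtxs.end()) {
      return *iter->second;
    }
  }
  std::unique_lock<std::shared_mutex> write_lock(map_mtx);
  auto [iter, inserted] = mtxs.emplace(stream, std::make_unique<std::mutex>());
  (void)inserted;  // false if another thread won the race; its mutex is used either way
  return *iter->second;
}

int main() {
  int fake_stream_a = 0, fake_stream_b = 0;  // stand-ins for real stream handles
  std::lock_guard<std::mutex> guard_a(MutexForStream(&fake_stream_a));
  std::lock_guard<std::mutex> guard_b(MutexForStream(&fake_stream_b));  // distinct mutex, no contention
  std::cout << "launches on different streams serialize independently" << std::endl;
  return 0;
}
```

Compared with the single `launch_mutex_` this replaces, launches on different streams no longer contend with each other; only launches targeting the same stream serialize.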
diff --git a/mindspore/ccsrc/plugin/device/gpu/hal/hardware/gpu_device_context.h b/mindspore/ccsrc/plugin/device/gpu/hal/hardware/gpu_device_context.h
index d2883f2642d..8137463df1b 100644
--- a/mindspore/ccsrc/plugin/device/gpu/hal/hardware/gpu_device_context.h
+++ b/mindspore/ccsrc/plugin/device/gpu/hal/hardware/gpu_device_context.h
@@ -92,12 +92,17 @@ class GPUDeviceContext : public DeviceContext {
 #ifndef ENABLE_SECURITY
   // Launch a kernel and record the elapsed time end to end.
   bool LaunchKernelWithProfiling(const CNodePtr &kernel, const std::vector<AddressPtr> &inputs,
-                                 const std::vector<AddressPtr> &workspace,
-                                 const std::vector<AddressPtr> &outputs) const;
+                                 const std::vector<AddressPtr> &workspace, const std::vector<AddressPtr> &outputs,
+                                 void *stream) const;
 #endif

   // Launch a kernel by 'KernelMod' of the kernel.
   bool DoLaunchKernel(const CNodePtr &kernel, const std::vector<AddressPtr> &inputs,
-                      const std::vector<AddressPtr> &workspace, const std::vector<AddressPtr> &outputs) const;
+                      const std::vector<AddressPtr> &workspace, const std::vector<AddressPtr> &outputs,
+                      void *stream) const;
+
+  // Get the stream used to launch the kernel: if a stream is saved in the kernel's attrs, use it; otherwise use the
+  // default stream.
+  void *GetLaunchKernelStream(const CNodePtr &kernel) const;

   // Really create a cuda stream.
   bool CreateStream(void **stream) const override;
@@ -105,10 +110,6 @@ class GPUDeviceContext : public DeviceContext {
   // Really destroy a cuda stream.
   bool DestroyStream(void *stream) const override;

-  // The cublas handle is not thread safety specifically, it is not recommended that multiple threads access the same
-  // cublas handle at the same time, so need the launch mutex when multiple threads launch the cublas kernels.
-  mutable std::mutex launch_mutex_;
-
   std::shared_ptr mem_manager_;
   std::vector<void *> streams_;
   bool initialized_;
diff --git a/mindspore/ccsrc/ps/ps_cache/ps_data/ps_data_prefetch.cc b/mindspore/ccsrc/ps/ps_cache/ps_data/ps_data_prefetch.cc
index 02ff1211d99..71ee1a1b947 100644
--- a/mindspore/ccsrc/ps/ps_cache/ps_data/ps_data_prefetch.cc
+++ b/mindspore/ccsrc/ps/ps_cache/ps_data/ps_data_prefetch.cc
@@ -22,6 +22,11 @@ namespace ps {
 const size_t kTimeoutLoopCount = 40;
 const int64_t kLongestTimeToWait = 30;

+PsDataPrefetch &PsDataPrefetch::GetInstance() {
+  static PsDataPrefetch instance;
+  return instance;
+}
+
 void PsDataPrefetch::CreateDataChannel(const std::string &channel_name, size_t step_num) {
   if (cache_enable_ == false) {
     return;
diff --git a/mindspore/ccsrc/ps/ps_cache/ps_data/ps_data_prefetch.h b/mindspore/ccsrc/ps/ps_cache/ps_data/ps_data_prefetch.h
index 02d3fa78876..6249f3ea89e 100644
--- a/mindspore/ccsrc/ps/ps_cache/ps_data/ps_data_prefetch.h
+++ b/mindspore/ccsrc/ps/ps_cache/ps_data/ps_data_prefetch.h
@@ -29,10 +29,7 @@ namespace mindspore {
 namespace ps {
 class EXPORT PsDataPrefetch {
  public:
-  EXPORT static PsDataPrefetch &GetInstance() {
-    static PsDataPrefetch instance;
-    return instance;
-  }
+  EXPORT static PsDataPrefetch &GetInstance();

   EXPORT bool cache_enable() const { return cache_enable_; }
   EXPORT void set_cache_enable(bool cache_enable) { cache_enable_ = cache_enable; }
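Moving `GetInstance` from the header into the `.cc` file is the usual fix for singletons exported from shared libraries: with the body inline in the header, each DSO that includes it can instantiate its own `instance`, so different libraries see different singletons. A minimal sketch of the pattern with a hypothetical `Config` class (both "files" shown in one block for brevity):

```cpp
#include <iostream>

// config.h (sketch): only declare the accessor; keeping the body out of the
// header means every translation unit and shared library links against one
// definition.
class Config {
 public:
  static Config &GetInstance();
  int value() const { return value_; }

 private:
  Config() = default;
  int value_{42};
};

// config.cc (sketch): the single definition. The function-local static is
// initialized once, thread-safely, on first call (C++11 "magic static").
Config &Config::GetInstance() {
  static Config instance;
  return instance;
}

int main() {
  // Every caller, in every translation unit, sees the same object.
  std::cout << Config::GetInstance().value() << std::endl;
  return 0;
}
```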
#include "include/common/utils/anfalgo.h" +#include "ps/ps_context.h" namespace mindspore { namespace runtime { @@ -391,5 +392,9 @@ std::set FetchModifiableRefOutputIndex(const CNodePtr &cnode, const Kern } return ref_output_indexes; } + +bool is_embedding_cache_server() { + return ps::PSContext::instance()->cache_enable() && ps::PSContext::instance()->is_server(); +} } // namespace runtime } // namespace mindspore diff --git a/mindspore/ccsrc/runtime/graph_scheduler/actor/actor_common.h b/mindspore/ccsrc/runtime/graph_scheduler/actor/actor_common.h index fefc21b5064..6be57a3eb92 100644 --- a/mindspore/ccsrc/runtime/graph_scheduler/actor/actor_common.h +++ b/mindspore/ccsrc/runtime/graph_scheduler/actor/actor_common.h @@ -285,6 +285,9 @@ std::string FetchActorName(KernelTransformType kernel_type, const std::string &a std::set FetchModifiableRefInputIndex(const CNodePtr &node); // Fetch the output indexes which may be modified that exist in the ref node. std::set FetchModifiableRefOutputIndex(const CNodePtr &node, const KernelGraphPtr &graph); + +// Check whether this process is parameter server and enable embedding cache. +bool is_embedding_cache_server(); } // namespace runtime } // namespace mindspore diff --git a/mindspore/ccsrc/runtime/graph_scheduler/actor/embedding_cache/embedding_cache_prefetch_actor.cc b/mindspore/ccsrc/runtime/graph_scheduler/actor/embedding_cache/embedding_cache_prefetch_actor.cc index 7218ca26bb1..8969c9c1c0d 100644 --- a/mindspore/ccsrc/runtime/graph_scheduler/actor/embedding_cache/embedding_cache_prefetch_actor.cc +++ b/mindspore/ccsrc/runtime/graph_scheduler/actor/embedding_cache/embedding_cache_prefetch_actor.cc @@ -22,6 +22,7 @@ namespace mindspore { namespace runtime { +using distributed::cluster::ClusterContext; using mindspore::session::KernelGraph; // One and two dimensional shape placeholder. @@ -58,7 +59,7 @@ bool InferOpShape(const CNodePtr &kernel) { auto kernel_mod = AnfAlgo::GetKernelMod(kernel); MS_EXCEPTION_IF_NULL(kernel_mod); - if (!kernel_mod->Resize(args->op, args->inputs, args->outputs, args->depend_tensor_map)) { + if (kernel::KRET_OK != kernel_mod->Resize(args->op, args->inputs, args->outputs, args->depend_tensor_map)) { MS_LOG(ERROR) << "Kernel " << kernel->fullname_with_scope() << " resize failed."; return false; } @@ -88,14 +89,17 @@ SendRecvPair CreateSenderReceiverPair(uint32_t worker_rank, uint32_t server_rank int32_t param_key) { // Create sender and receiver pair. ReceiverPtr receiver = std::make_shared(); - SenderPtr sender = std::make_shared(receiver); + MS_EXCEPTION_IF_NULL(receiver); + SenderPtr sender = std::make_shared(); + MS_EXCEPTION_IF_NULL(sender); + sender->set_receiver(receiver); // Set inter process edge - receiver->set_inter_process_edge_name(GenerateInterProcessEdge(distributed::kEnvRoleOfServer, server_rank, + receiver->set_inter_process_edge_name(GenerateInterProcessEdge(distributed::kEnvRoleOfPServer, server_rank, distributed::kEnvRoleOfWorker, worker_rank, cache_operation, param_key)); sender->set_inter_process_edge_name(GenerateInterProcessEdge(distributed::kEnvRoleOfWorker, worker_rank, - distributed::kEnvRoleOfServer, server_rank, + distributed::kEnvRoleOfPServer, server_rank, distributed::kLookupEmbeddingCache, param_key)); // Set route table proxy. @@ -172,9 +176,24 @@ void EmbeddingCachePrefetchActor::Initialize() { MS_LOG(EXCEPTION) << "Create stream failed."; } + // Get embedding cache table info. 
+  hash_tables_ = embedding_cache_table_manager.hash_tables_;
+  local_host_cache_size_ = embedding_cache_table_manager.host_cache_size_;
+  vocab_size_ = embedding_cache_table_manager.vocab_size_;
+  embedding_device_cache_ = embedding_cache_table_manager.embedding_device_cache_;
+  MS_EXCEPTION_IF_NULL(embedding_device_cache_);
+  embedding_host_cache_ = embedding_cache_table_manager.embedding_host_cache_;
+  MS_EXCEPTION_IF_NULL(embedding_host_cache_);
+  local_embedding_slice_bounds_ = embedding_cache_table_manager.local_embedding_slice_bounds_;
+  local_device_cache_bounds_ = embedding_cache_table_manager.local_device_cache_bounds_;
+
+  // Get the id range of each server's embedding table slice.
+  GetRemoteEmbeddingSliceBound();
+
   BuildEmbeddingCacheLookupKernel();
   BuildEmbeddingCacheUpdateKernel();

+  // Build and link rpc operators.
   BuildRpcOperators();
   LinkRpcOperators();
 }
@@ -326,8 +345,40 @@ bool EmbeddingCachePrefetchActor::UpdateDeviceCache(void *indices, void *update_
   return true;
 }

+void EmbeddingCachePrefetchActor::IncreaseGraphStep(const std::string &channel_name) {
+  if (!running_) {
+    MS_LOG(EXCEPTION) << "PS embedding cache data processing thread isn't running.";
+  }
+  if (graph_step_ >= UINT64_MAX) {
+    MS_LOG(EXCEPTION) << "The graph step(" << graph_step_ << ") will exceed the maximum value of uint64_t.";
+  }
+  if (graph_step_ == 0) {
+    MS_LOG(INFO) << "Graph running is waiting for embedding table initialization, finished: "
+                 << finish_init_parameters_on_remote_;
+    std::unique_lock<std::mutex> locker(data_mutex_);
+    data_parser_.wait(locker, [this] { return finish_init_parameters_on_remote_ || !running_; });
+    if (!running_) {
+      MS_LOG(EXCEPTION) << "PS embedding cache data processing thread isn't running.";
+    }
+    MS_LOG(INFO) << "Embedding table initialization on the remote side finished.";
+  }
+  graph_step_++;
+  set_channel_name(channel_name);
+  if (!PsDataPrefetch::GetInstance().TryWakeChannel(channel_name)) {
+    MS_LOG(EXCEPTION) << "TryWakeChannel failed, channel name: " << channel_name;
+  }
+  data_parser_.notify_one();
+}
+
 void EmbeddingCachePrefetchActor::Run() {
-  // Note:Need to wait data channel ready.
+  running_ = true;
+
+  // Wait for initialization of parameters on the remote side.
+  // This prevents subsequent cache prefetching from failing when large parameters take a long time to initialize
+  // remotely.
+  WaitInitParametersOnRemote();
+
+  // Wait until the data channel is ready.
+  WaitDataChannelInit();

   MS_LOG(INFO) << "Begin prefetching cache.";
   while (running_) {
@@ -728,7 +779,6 @@ bool EmbeddingCachePrefetchActor::PushCacheFromDeviceToLocalHost(const HashTable
   }

   MS_ERROR_IF_NULL(embedding_device_cache_);
-  MS_ERROR_IF_NULL(embedding_device_cache_->cache_);
   MS_ERROR_IF_NULL(embedding_host_cache_);

   auto device_cache_device_to_host_index = embedding_device_cache_->device_to_host_index.get();
@@ -798,7 +848,6 @@ bool EmbeddingCachePrefetchActor::PullCacheFromLocalHostToDevice(const HashTable
   }

   MS_ERROR_IF_NULL(embedding_device_cache_);
-  MS_ERROR_IF_NULL(embedding_device_cache_->cache_);
   MS_ERROR_IF_NULL(embedding_host_cache_);

   auto host_cache_host_to_device_index = embedding_host_cache_->host_to_device_index.get();
@@ -1126,7 +1175,6 @@ bool EmbeddingCachePrefetchActor::SendToRemote(const std::string &cache_operatio
                                                size_t server_rank_id, size_t embedding_dim, const void *keys,
                                                size_t keys_len, const void *values, size_t values_len) {
   MS_ERROR_IF_NULL(keys);
-  MS_ERROR_IF_NULL(values);
   // Find sender corresponding to cache operation and parameter key.
   auto iter = rpc_operators_.find(cache_operation);
   if (iter == rpc_operators_.end()) {
@@ -1138,8 +1186,26 @@ bool EmbeddingCachePrefetchActor::SendToRemote(const std::string &cache_operatio
   MS_ERROR_IF_NULL(sender);

   int64_t ids_num = SizeToLong(keys_len / sizeof(int));
-  std::vector<ShapeVector> shapes = {{ids_num}, {ids_num, SizeToLong(embedding_dim)}, {1}};
-  std::vector<TypeId> data_types = {kNumberTypeInt32, kNumberTypeFloat32, kNumberTypeInt64};
+  ShapeVector ids_shape = {ids_num};
+  ShapeVector values_shape;
+  float fake_value = 0.0;
+
+  if (values == nullptr && values_len == 0) {
+    values_shape = {1, 1};
+    values = &fake_value;
+    values_len = sizeof(fake_value);
+  } else {
+    MS_EXCEPTION_IF_ZERO("embedding_dim", embedding_dim);
+    int64_t embed_vec_num = SizeToLong(values_len / sizeof(float) / embedding_dim);
+    if (embed_vec_num != ids_num) {
+      MS_LOG(EXCEPTION) << "The embedding vector number[" << embed_vec_num << "] should be equal to the ids number["
+                        << ids_num << "] which will be sent to the remote side.";
+    }
+    values_shape = {embed_vec_num, SizeToLong(embedding_dim)};
+  }
+
+  std::vector<ShapeVector> shapes = {ids_shape, values_shape, {static_cast<int64_t>(1)}};
+  std::vector<TypeId> data_types = {kNumberTypeInt32, kNumberTypeFloat32, kNumberTypeInt32};

   int32_t service_id = GetCacheOpsServiceId(cache_operation, param_key);
   AddressPtrList data_list = {std::make_shared<Address>(const_cast<void *>(keys), keys_len),
@@ -1220,6 +1286,39 @@ bool EmbeddingCachePrefetchActor::RetrieveEmbeddings(
   return true;
 }

+std::string EmbeddingCachePrefetchActor::channel_name() {
+  std::lock_guard<std::mutex> locker(channel_mutex_);
+  return channel_name_;
+}
+
+void EmbeddingCachePrefetchActor::set_channel_name(const std::string &channel_name) {
+  // Take the lock before comparing, so the read of channel_name_ cannot race with a concurrent setter.
+  std::lock_guard<std::mutex> locker(channel_mutex_);
+  if (channel_name_ == channel_name) {
+    return;
+  }
+  channel_name_ = channel_name;
+}
+
+void EmbeddingCachePrefetchActor::WaitDataChannelInit() {
+  MS_LOG(INFO) << "Begin waiting for embedding cache data channel init.";
+  auto channel = channel_name();
+  if (channel.empty()) {
+    std::unique_lock<std::mutex> locker(data_mutex_);
+    data_parser_.wait(locker, [this] { return !channel_name_.empty() || !running_; });
+    if (!running_) {
+      return;
+    }
+  }
+  MS_LOG(INFO) << "End waiting for embedding cache data channel init.";
+}
+
+void EmbeddingCachePrefetchActor::WaitInitParametersOnRemote() {
+  std::unique_lock<std::mutex> locker(data_mutex_);
+  // Note: wait to finish embedding lookup from remote.
+  finish_init_parameters_on_remote_ = true;
+  data_parser_.notify_one();
+}
+
 void EmbeddingCachePrefetchActor::BuildRpcOperators() {
   // The cache operation supports LookupEmbeddingCache and UpdateEmbeddingCache currently.
   for (const auto &cache_op : distributed::kEmbeddingCacheOps) {
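The channel accessors above guard `channel_name_` with a dedicated mutex; the important detail is that the equality check happens after the lock is taken, otherwise the unsynchronized read could race with a concurrent setter. A standalone sketch of the accessor pattern (illustrative `ChannelHolder` class, not MindSpore code):

```cpp
#include <iostream>
#include <mutex>
#include <string>

class ChannelHolder {
 public:
  // Reads and writes of name_ are both serialized by mtx_; comparing before
  // locking would read name_ unsynchronized and race with a concurrent setter.
  std::string name() {
    std::lock_guard<std::mutex> lock(mtx_);
    return name_;
  }
  void set_name(const std::string &name) {
    std::lock_guard<std::mutex> lock(mtx_);
    if (name_ == name) {
      return;  // already up to date
    }
    name_ = name;
  }

 private:
  std::mutex mtx_;
  std::string name_;
};

int main() {
  ChannelHolder holder;
  holder.set_name("train_channel");
  std::cout << holder.name() << std::endl;
  return 0;
}
```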
diff --git a/mindspore/ccsrc/runtime/graph_scheduler/actor/embedding_cache/embedding_cache_prefetch_actor.h b/mindspore/ccsrc/runtime/graph_scheduler/actor/embedding_cache/embedding_cache_prefetch_actor.h
index 9d7c6905906..f4387a15837 100644
--- a/mindspore/ccsrc/runtime/graph_scheduler/actor/embedding_cache/embedding_cache_prefetch_actor.h
+++ b/mindspore/ccsrc/runtime/graph_scheduler/actor/embedding_cache/embedding_cache_prefetch_actor.h
@@ -30,19 +30,12 @@
 #include "distributed/rpc/tcp/tcp_client.h"
 #include "distributed/rpc/tcp/tcp_server.h"
 #include "utils/hash_map.h"
+#include "distributed/embedding_cache/embedding_cache_utils.h"

 // Note: After the code in ps/ps_cache is moved into runtime/addons/embedding_cache/,
 // the following include files and using declarations of ps will be removed.
 #include "ps/ps_cache/ps_data/ps_data_prefetch.h"
-#include "ps/ps_cache/embedding_hash_map.h"
-#include "ps/ps_cache/ps_cache_manager.h"
 #include "ps/ps_context.h"

-using mindspore::ps::EmbeddingDeviceCache;
-using mindspore::ps::EmbeddingHostCache;
-using mindspore::ps::HashTableInfo;
-using mindspore::ps::INVALID_INDEX_VALUE;
-using mindspore::ps::INVALID_STEP_VALUE;
-using mindspore::ps::PsCacheStatisticsInfo;
 using mindspore::ps::PSContext;
 using mindspore::ps::PsDataPrefetch;
@@ -59,8 +52,17 @@
 using ReceiverPtr = std::shared_ptr<Receiver>;
 using SendRecvPair = std::pair<SenderPtr, ReceiverPtr>;
 using SendRecvPairList = std::vector<SendRecvPair>;

+using distributed::EmbeddingCacheStatisticsInfo;
+using distributed::EmbeddingDeviceCache;
+using distributed::EmbeddingHostCache;
+using distributed::HashTableInfo;
+using distributed::INVALID_INDEX_VALUE;
+using distributed::INVALID_STEP_VALUE;
+
 using distributed::cluster::ActorRouteTableProxy;
 using distributed::cluster::ActorRouteTableProxyPtr;
+using distributed::rpc::TCPClient;
+using distributed::rpc::TCPServer;

 // The EmbeddingCachePrefetchActor is used to cache large embedding table scenarios. The cache level is: Device
 // Cache->Local Host Cache->Remote Cache. This Actor is used to perform Local and Device Cache hit analysis and cache
@@ -82,6 +84,9 @@ class EmbeddingCachePrefetchActor : public ActorBase {
   // Perform local cache hit analysis, prefetch the feature vector corresponding to the next batch into the cache.
   void Run();

+  // Increase the global step of the compute graph.
+  void IncreaseGraphStep(const std::string &channel_name);
+
   // Finalize embedding cache prefetch actor and push latest embedding from local cache to remote cache.
   void Finalize();
@@ -206,6 +211,19 @@ class EmbeddingCachePrefetchActor : public ActorBase {
   bool UpdateDeviceCache(void *indices, void *update_value, size_t indices_num, size_t cache_size,
                          size_t embedding_size, void *embedding_cache);

+  // Get the dataset channel name.
+  std::string channel_name();
+  // Set the dataset channel name.
+  void set_channel_name(const std::string &channel_name);
+
+  // Wait until the data channel is ready.
+  void WaitDataChannelInit();
+
+  // Wait for initialization of parameters on the remote side.
+  // This prevents subsequent cache prefetching from failing when large parameters take a long time to initialize
+  // remotely.
+  void WaitInitParametersOnRemote();
+
   // Record sender and receiver pairs for different cache operation, server and parameter key.
   // key: cache operation(such as LookupEmbeddingCache and UpdateEmbeddingCache)
   // value: sender and receiver pairs for this kind of cache operation.
@@ -239,7 +257,7 @@ class EmbeddingCachePrefetchActor : public ActorBase {
   std::shared_ptr<EmbeddingHostCache> embedding_host_cache_;

   // Statistics on the cache hit rate of the host and device and the information used to update cache.
-  PsCacheStatisticsInfo statistics_info_;
+  EmbeddingCacheStatisticsInfo statistics_info_;

   // Model parallelism is used between multiple workers, and local_embedding_slice_bounds_ records the feature range
   // corresponding to the embedding table slice of the process.
@@ -255,7 +273,7 @@ class EmbeddingCachePrefetchActor : public ActorBase {
   std::vector<std::pair<size_t, size_t>> remote_embedding_slice_bounds_;

   // Total server number of the cluster.
-  size_t server_num_;
+  size_t server_num_{0};

   // The flag which indicates whether this actor is running to prefetch cache.
   std::atomic_bool running_{false};
@@ -269,6 +287,11 @@ class EmbeddingCachePrefetchActor : public ActorBase {
   // Dataset channel name, used in dataset switching scenarios.
   std::string channel_name_;
+  // The mutex to access channel_name_.
+  std::mutex channel_mutex_;
+
+  // The flag which indicates whether parameter initialization on the remote side has finished.
+  std::atomic_bool finish_init_parameters_on_remote_{false};

   // Data parser condition variable for prefetching cache, used to start and synchronize intermediate state for cache
   // prefetching.
@@ -311,13 +334,16 @@ class RpcOperator {
 // Sender is used to send data to other process.
 class Sender : public RpcOperator {
  public:
-  explicit Sender(const ReceiverPtr &receiver) : server_url_(""), client_(nullptr), receiver_(receiver) {}
+  Sender() : server_url_(""), client_(nullptr) {}
   ~Sender();

   // Send buffer to peer.
   bool Send(const std::vector<ShapeVector> &shapes, const std::vector<TypeId> data_types,
             const AddressPtrList &data_list) const;

+  // Set the receiver paired with the sender to get the 'from url' from the receiver.
+  void set_receiver(const ReceiverPtr &receiver) { receiver_ = receiver; }
+
   // Lookup peer receiver's route and build network connection.
  bool ConnectServer();
diff --git a/mindspore/ccsrc/runtime/graph_scheduler/embedding_cache_scheduler.cc b/mindspore/ccsrc/runtime/graph_scheduler/embedding_cache_scheduler.cc
index c2f2c67a83e..2d7d8a1f9e9 100644
--- a/mindspore/ccsrc/runtime/graph_scheduler/embedding_cache_scheduler.cc
+++ b/mindspore/ccsrc/runtime/graph_scheduler/embedding_cache_scheduler.cc
@@ -15,10 +15,10 @@
  */

 #include "runtime/graph_scheduler/embedding_cache_scheduler.h"
-#include "runtime/graph_scheduler/actor/embedding_cache/embedding_cache_prefetch_actor.h"
-
 #include
 #include
+#include "runtime/graph_scheduler/actor/embedding_cache/embedding_cache_prefetch_actor.h"
+#include "utils/ms_context.h"

 namespace mindspore {
 namespace runtime {