diff --git a/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.h b/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.h index ee1841c6d36..311eea61d7d 100644 --- a/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.h +++ b/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.h @@ -52,6 +52,7 @@ struct HashTableInfo { Address device_address{nullptr, 0}; std::shared_ptr host_address{nullptr}; ParamInitInfo param_init_info_; + int32_t param_key_{-1}; }; struct EmbeddingDeviceCache { diff --git a/mindspore/ccsrc/runtime/graph_scheduler/actor/embedding_cache/embedding_cache_prefetch_actor.cc b/mindspore/ccsrc/runtime/graph_scheduler/actor/embedding_cache/embedding_cache_prefetch_actor.cc index a70253f92fa..a8adedae9d5 100644 --- a/mindspore/ccsrc/runtime/graph_scheduler/actor/embedding_cache/embedding_cache_prefetch_actor.cc +++ b/mindspore/ccsrc/runtime/graph_scheduler/actor/embedding_cache/embedding_cache_prefetch_actor.cc @@ -17,11 +17,10 @@ #include "runtime/graph_scheduler/actor/embedding_cache/embedding_cache_prefetch_actor.h" #include "backend/common/optimizer/dynamic_shape/dynamic_shape_helper.h" #include "kernel/common_utils.h" +#include "runtime/graph_scheduler/actor/rpc/rpc_actor.h" namespace mindspore { namespace runtime { -using kernel::Address; -using kernel::AddressPtrList; using mindspore::session::KernelGraph; // One and two dimensional shape placeholder. @@ -64,6 +63,66 @@ bool InferOpShape(const CNodePtr &kernel) { } return true; } + +// Generate unique inter process edge name, format: +// src role + src rank id -> dst role + dst rank id + embedding cache operation + parameter key. +std::string GenerateInterProcessEdge(const std::string &src_role, uint32_t src_rank, const std::string &dst_role, + uint32_t dst_rank, const std::string &cache_operation, int32_t param_key) { + std::string edge = src_role + std::to_string(src_rank) + "->" + dst_role + std::to_string(dst_rank) + "_" + + cache_operation + "_" + distributed::kParameterKey + std::to_string(param_key); + return edge; +} + +ActorRouteTableProxyPtr CreateRouteTableProxy() { + auto node = ClusterContext::instance()->node(); + ActorRouteTableProxyPtr actor_route_table_proxy = + std::make_shared(std::dynamic_pointer_cast(node)); + MS_EXCEPTION_IF_NULL(actor_route_table_proxy); + return actor_route_table_proxy; +} + +// Create a sender and receiver pair,The sender and receiver are paired. +// When creating a sender, need to create and specify the receiver paired with it in advance. +SendRecvPair CreateSenderReceiverPair(uint32_t worker_rank, uint32_t server_rank, const std::string &cache_operation, + int32_t param_key) { + // Create sender and receiver pair. + ReceiverPtr receiver = std::make_shared(); + SenderPtr sender = std::make_shared(receiver); + + // Set inter process edge + receiver->set_inter_process_edge_name(GenerateInterProcessEdge(distributed::kEnvRoleOfServer, server_rank, + distributed::kEnvRoleOfWorker, worker_rank, + cache_operation, param_key)); + sender->set_inter_process_edge_name(GenerateInterProcessEdge(distributed::kEnvRoleOfWorker, worker_rank, + distributed::kEnvRoleOfServer, server_rank, + distributed::kLookupEmbeddingCache, param_key)); + + // Set route table proxy. + receiver->set_actor_route_table_proxy(CreateRouteTableProxy()); + sender->set_actor_route_table_proxy(CreateRouteTableProxy()); + + return std::make_pair(sender, receiver); +} + +// Get cache operation service id which is used to decide which set of cache services to request. +// The server side executes the corresponding service according to this id. +int64_t GetCacheOpsServiceId(const std::string &cache_operation, int32_t param_key) { + static mindspore::HashMap cache_ops_to_index; + if (cache_ops_to_index.empty()) { + int64_t cnt = 0; + for (const auto &cache_op : distributed::kEmbeddingCacheOps) { + cache_ops_to_index[cache_op] = cnt++; + } + } + + auto iter = cache_ops_to_index.find(cache_operation); + if (iter == cache_ops_to_index.end()) { + MS_LOG(EXCEPTION) << "Can not find index for cache operation: " << cache_operation; + } + + int64_t id = SizeToLong(distributed::kEmbeddingCacheOps.size()) * IntToLong(param_key) + iter->second; + return id; +} } // namespace void EmbeddingCachePrefetchActor::Initialize() { @@ -76,6 +135,8 @@ void EmbeddingCachePrefetchActor::Initialize() { void EmbeddingCachePrefetchActor::Finalize() { embedding_cache_lookup_node_ = nullptr; embedding_cache_update_node_ = nullptr; + + rpc_operators_.clear(); } void EmbeddingCachePrefetchActor::BuildEmbeddingCacheLookupKernel() { @@ -607,8 +668,8 @@ bool EmbeddingCachePrefetchActor::PushCacheFromLocalHostToRemote(const HashTable RETURN_IF_FALSE_WITH_LOG(LookupLocalHostCache(embedding_size, swap_indices_size, host_hash_table_addr, host_to_server_index, swap_out_data.data()), "Lookup local host cache failed."); - RETURN_IF_FALSE_WITH_LOG(PushEmbeddingsToRemote(host_to_server_ids, swap_indices_size, swap_out_data.data(), - swap_out_data.size() * sizeof(float)), + RETURN_IF_FALSE_WITH_LOG(PushEmbeddingsToRemote(hash_info.param_key_, host_to_server_ids, swap_indices_size, + swap_out_data.data(), swap_out_data.size() * sizeof(float)), "Push embeddings to remote failed."); return true; } @@ -663,8 +724,9 @@ bool EmbeddingCachePrefetchActor::PullCacheFromRemoteToLocalHost(const HashTable auto embedding_size = hash_info.embedding_size; std::vector lookup_result(swap_indices_size * embedding_size, 0); - RETURN_IF_FALSE_WITH_LOG(PullEembeddingsFromRemote(server_to_host_ids, swap_indices_size, &lookup_result), - "Pull embedding from remote failed."); + RETURN_IF_FALSE_WITH_LOG( + PullEembeddingsFromRemote(hash_info.param_key_, server_to_host_ids, swap_indices_size, &lookup_result), + "Pull embedding from remote failed."); RETURN_IF_FALSE_WITH_LOG(InsertLocalHostCache(embedding_size, IntToSize(swap_indices_size), server_to_host_index, lookup_result.data(), host_hash_table_addr), "Insert local host cache failed."); @@ -823,15 +885,21 @@ bool EmbeddingCachePrefetchActor::LookupLocalHostCache(size_t embedding_size, si return running_; } -bool EmbeddingCachePrefetchActor::PullEembeddingsFromRemote(const int *ids, size_t ids_num, +bool EmbeddingCachePrefetchActor::PullEembeddingsFromRemote(int32_t param_key, const int *ids, size_t ids_num, std::vector *outputs) { MS_ERROR_IF_NULL(ids); MS_ERROR_IF_NULL(outputs); + if (ids_num == 0) { + MS_LOG(WARNING) << "The ids number is 0"; + return true; + } + std::vector> slice_ids_list(server_num_); // 1. Partition ids by remote embedding slice bound and get unique ids. RETURN_IF_FALSE_WITH_LOG(PartitionIds(ids, ids_num, &slice_ids_list), "Partition ids failed."); + size_t embedding_dim = outputs->size() / ids_num; for (size_t i = 0; i < server_num_; i++) { auto &slice_ids = slice_ids_list[i]; if (slice_ids.empty()) { @@ -839,19 +907,28 @@ bool EmbeddingCachePrefetchActor::PullEembeddingsFromRemote(const int *ids, size } // 2. Send unique ids to remote to do embedding lookup. - RETURN_IF_FALSE_WITH_LOG(SendToRemote(i, slice_ids.data(), slice_ids.size() * sizeof(int)), + RETURN_IF_FALSE_WITH_LOG(SendToRemote(distributed::kLookupEmbeddingCache, param_key, i, embedding_dim, + slice_ids.data(), slice_ids.size() * sizeof(int)), "Send ids to server failed."); } - std::vector> slice_embeddings_list(server_num_); + std::vector>> slice_embeddings_list(server_num_); for (size_t i = 0; i < server_num_; i++) { if (slice_ids_list[i].empty()) { continue; } // 3. Wait embeddings result. - auto &slice_embeddings = slice_embeddings_list[i]; - RETURN_IF_FALSE_WITH_LOG(WaitRespFromRemote(i, &slice_embeddings), "Wait response from server failed."); + slice_embeddings_list[i] = ReceiveFromRemote(distributed::kLookupEmbeddingCache, param_key, i); + MS_ERROR_IF_NULL(slice_embeddings_list[i]); + // Received embedding integrity check. + size_t expected_embedding_size = SizetMulWithOverflowCheck(slice_ids_list[i].size(), embedding_dim); + size_t received_embedding_size = slice_embeddings_list[i]->size() / sizeof(float); + if (received_embedding_size != expected_embedding_size) { + MS_LOG(ERROR) << "Received embedding data from remote is incomplete, expected embedding size: " + << expected_embedding_size << ", but received embedding size: " << received_embedding_size; + return false; + } } // 4. Retrieve embeddings by input ids order. @@ -861,11 +938,16 @@ bool EmbeddingCachePrefetchActor::PullEembeddingsFromRemote(const int *ids, size return true; } -bool EmbeddingCachePrefetchActor::PushEmbeddingsToRemote(const int *ids, size_t ids_num, const float *embeddings, - size_t embeddings_len) { +bool EmbeddingCachePrefetchActor::PushEmbeddingsToRemote(int32_t param_key, const int *ids, size_t ids_num, + const float *embeddings, size_t embeddings_len) { MS_ERROR_IF_NULL(ids); MS_ERROR_IF_NULL(embeddings); + if (ids_num == 0) { + MS_LOG(WARNING) << "The ids number is 0"; + return true; + } + std::vector> slice_ids_list(server_num_); std::vector> slice_embeddings_list(server_num_); // 1. Partition ids end embeddings by remote embedding slice bound. @@ -873,6 +955,7 @@ bool EmbeddingCachePrefetchActor::PushEmbeddingsToRemote(const int *ids, size_t PartitionIdsAndEmbeddings(ids, ids_num, embeddings, embeddings_len, &slice_ids_list, &slice_embeddings_list), "Partition ids and embeddings failed."); + size_t embedding_dim = (embeddings_len / ids_num) / sizeof(float); for (size_t i = 0; i < server_num_; i++) { auto &slice_ids = slice_ids_list[i]; if (slice_ids.empty()) { @@ -881,9 +964,10 @@ bool EmbeddingCachePrefetchActor::PushEmbeddingsToRemote(const int *ids, size_t // 2. Send embeddings to remote. auto &slice_embeddings = slice_embeddings_list[i]; - RETURN_IF_FALSE_WITH_LOG(SendToRemote(i, slice_ids.data(), slice_ids.size() * sizeof(int), slice_embeddings.data(), - slice_embeddings.size() * sizeof(float)), - "Send ids and embeddings to server failed."); + RETURN_IF_FALSE_WITH_LOG( + SendToRemote(distributed::kUpdateEmbeddingCache, param_key, i, embedding_dim, slice_ids.data(), + slice_ids.size() * sizeof(int), slice_embeddings.data(), slice_embeddings.size() * sizeof(float)), + "Send ids and embeddings to server failed."); } return true; @@ -971,21 +1055,51 @@ bool EmbeddingCachePrefetchActor::PartitionIdsAndEmbeddings(const int *ids, size return true; } -bool EmbeddingCachePrefetchActor::SendToRemote(size_t server_rank_id, const void *keys, size_t keys_len, - const void *values, size_t values_len) { - // Note: Need to implement the method via send actor. - return true; +bool EmbeddingCachePrefetchActor::SendToRemote(const std::string &cache_operation, int32_t param_key, + size_t server_rank_id, size_t embedding_dim, const void *keys, + size_t keys_len, const void *values, size_t values_len) { + // Find sender corresponding to cache operation and parameter key. + auto iter = rpc_operators_.find(cache_operation); + if (iter == rpc_operators_.end()) { + MS_LOG(ERROR) << "Can not find rpc operator for cache operation: " << cache_operation; + } + + const std::vector &send_recv_pair_lists = iter->second; + const SenderPtr &sender = send_recv_pair_lists[server_rank_id][param_key].first; + MS_ERROR_IF_NULL(sender); + + int64_t ids_num = SizeToLong(keys_len / sizeof(int)); + std::vector shapes = {{ids_num}, {ids_num, SizeToLong(embedding_dim)}, {1}}; + std::vector data_types = {kNumberTypeInt32, kNumberTypeFloat32, kNumberTypeInt64}; + + int64_t service_id = GetCacheOpsServiceId(cache_operation, param_key); + AddressPtrList data_list = {std::make_shared
(const_cast(keys), keys_len), + std::make_shared
(const_cast(values), values_len), + std::make_shared
(&service_id, sizeof(int64_t))}; + + // Send data. + return sender->Send(shapes, data_types, data_list); } -bool EmbeddingCachePrefetchActor::WaitRespFromRemote(size_t server_rank_id, std::vector *outputs) { - // Note: Need to implement the method via recv actor. - return true; +std::unique_ptr> EmbeddingCachePrefetchActor::ReceiveFromRemote(const std::string &cache_operation, + int32_t param_key, + size_t server_rank_id) { + // Find receiver corresponding to cache operation and parameter key. + auto iter = rpc_operators_.find(cache_operation); + if (iter == rpc_operators_.end()) { + MS_LOG(ERROR) << "Can not find rpc operator for cache operation: " << cache_operation; + } + + const std::vector &send_recv_pair_lists = iter->second; + const ReceiverPtr &receiver = send_recv_pair_lists[server_rank_id][param_key].second; + MS_EXCEPTION_IF_NULL(receiver); + // Receive data. + return receiver->Receive(); } -bool EmbeddingCachePrefetchActor::RetrieveEmbeddings(const int *ids, size_t ids_num, - const std::vector> &slice_ids_list, - const std::vector> &slice_embeddings_list, - std::vector *outputs) { +bool EmbeddingCachePrefetchActor::RetrieveEmbeddings( + const int *ids, size_t ids_num, const std::vector> &slice_ids_list, + const std::vector>> &slice_embeddings_list, std::vector *outputs) { MS_ERROR_IF_NULL(ids); MS_ERROR_IF_NULL(outputs); @@ -1003,8 +1117,9 @@ bool EmbeddingCachePrefetchActor::RetrieveEmbeddings(const int *ids, size_t ids_ if (slice_ids.empty()) { continue; } - const std::vector &slice_embeddings = slice_embeddings_list[i]; - const float *embeddings_data = slice_embeddings.data(); + const std::unique_ptr> &slice_embeddings = slice_embeddings_list[i]; + MS_ERROR_IF_NULL(slice_embeddings); + const float *embeddings_data = reinterpret_cast(slice_embeddings->data()); for (size_t j = 0; j < slice_ids.size(); j++) { (void)ids_to_addrs.emplace(slice_ids[j], embeddings_data + offset); offset += embedding_dim; @@ -1035,5 +1150,287 @@ bool EmbeddingCachePrefetchActor::RetrieveEmbeddings(const int *ids, size_t ids_ } return true; } + +void EmbeddingCachePrefetchActor::BuildRpcOperators() { + // The cache operation support LookupEmbeddingCache and UpdateEmbeddingCache currently. + for (const auto &cache_op : distributed::kEmbeddingCacheOps) { + rpc_operators_[cache_op] = std::vector(); + rpc_operators_[cache_op].resize(server_num_); + } + + auto node = distributed::cluster::ClusterContext::instance()->node(); + MS_EXCEPTION_IF_NULL(node); + uint32_t worker_rank_id = node->rank_id(); + + // Create sender and receiver pairs for different cache operation, server and parameter key. + for (auto &item : rpc_operators_) { + const std::string &cache_op = item.first; + std::vector &send_recv_pair_lists = item.second; + for (uint32_t i = 0; i < server_num_; i++) { + SendRecvPairList &send_recv_pair_list = send_recv_pair_lists[i]; + send_recv_pair_list.resize(hash_tables_.size()); + + for (const auto &table : hash_tables_) { + int32_t key = table.second.param_key_; + if (key >= SizeToInt(hash_tables_.size()) || key < 0) { + MS_LOG(EXCEPTION) << "Invalid parameter key: " << key; + } + + send_recv_pair_list[key] = CreateSenderReceiverPair(worker_rank_id, i, cache_op, key); + } + } + } +} + +void EmbeddingCachePrefetchActor::LinkRpcOperators() { + std::vector senders; + std::vector receivers; + for (const auto &item : rpc_operators_) { + const std::vector &send_recv_pair_lists = item.second; + for (const SendRecvPairList &send_recv_pair_list : send_recv_pair_lists) { + for (const SendRecvPair &pair : send_recv_pair_list) { + senders.push_back(pair.first); + receivers.push_back(pair.second); + } + } + } + + // Must start server and register route table before looking up route and connecting. + // Start servers of receiver and register route table. + for (auto &receiver : receivers) { + MS_EXCEPTION_IF_NULL(receiver); + if (!receiver->StartServer()) { + MS_LOG(EXCEPTION) << "Failed to start server for the receiver."; + } + } + + // Lookup route and connect to servers for sender. + for (auto &sender : senders) { + MS_EXCEPTION_IF_NULL(sender); + if (!sender->ConnectServer()) { + MS_LOG(EXCEPTION) << "Failed to connect servers for the sender."; + } + } +} + +bool Sender::Send(const std::vector &shapes, const std::vector data_types, + const AddressPtrList &data_list) const { + MS_ERROR_IF_NULL(receiver_); + auto message = BuildRpcMessage(shapes, data_types, data_list, receiver_->get_url(), server_url_); + MS_ERROR_IF_NULL(message); + MS_ERROR_IF_NULL(client_); + client_->SendAsync(std::move(message)); + return true; +} + +Sender::~Sender() { + if (client_) { + client_->Disconnect(server_url_); + client_->Finalize(); + } + client_ = nullptr; + receiver_ = nullptr; +} + +bool Sender::ConnectServer() { + client_ = std::make_unique(); + MS_ERROR_IF_NULL(client_); + if (!client_->Initialize()) { + MS_LOG(ERROR) << "Failed to initialize tcp server for send actor."; + return false; + } + + // Lookup peer receiver addresses. + MS_ERROR_IF_NULL(route_table_proxy_); + auto peer_actor_address = route_table_proxy_->LookupRoute(inter_process_edge_); + server_url_ = peer_actor_address.ip() + ":" + std::to_string(peer_actor_address.port()); + if (!client_->Connect(server_url_)) { + MS_LOG(ERROR) << "Failed to connect to server of edge: " << inter_process_edge_ << ", server_url: " << server_url_; + return false; + } + + MS_LOG(INFO) << "Successfully connect to server " << server_url_ + << ", inter process edge name: " << inter_process_edge_; + return true; +} + +std::unique_ptr Sender::BuildRpcMessage(const std::vector &shapes, + const std::vector data_types, + const AddressPtrList &data_list, const std::string &from_url, + const std::string &to_url) const { + std::unique_ptr message = std::make_unique(); + MS_ERROR_IF_NULL_W_RET_VAL(message, nullptr); + message->from = AID("", from_url); + message->to = AID("", to_url); + + if (shapes.size() != data_list.size()) { + MS_LOG(ERROR) << "The shape list size[" << shapes.size() << "] should be equal to data list size[" + << data_list.size() << "]"; + } + + if (data_types.size() != data_list.size()) { + MS_LOG(ERROR) << "The date type list size[" << data_types.size() << "] should be equal to data list size[" + << data_list.size() << "]"; + } + + for (size_t i = 0; i < data_list.size(); i++) { + const ShapeVector &shape = shapes[i]; + const AddressPtr &data = data_list[i]; + const TypeId &type_id = data_types[i]; + + rpc::DynamicShapeMessage ds_pb_msg; + ds_pb_msg.set_type_id(type_id); + *ds_pb_msg.mutable_shape_vector() = {shape.begin(), shape.end()}; + std::string ds_pb_msg_str = ds_pb_msg.SerializeAsString(); + + // Message format: + // |RPC_DYNAMIC_SHAPE_DATA | dynamic shape PB data size |---dynamic shape PB data----|---real data----| + // 1. The dynamic shape header. + message->body.append(kRpcDynamicShapeData); + // 2. The size of the protobuf DynamicShapeMessage. + size_t ds_pb_msg_size = ds_pb_msg_str.size(); + message->body.append(reinterpret_cast(&ds_pb_msg_size), sizeof(ds_pb_msg_size)); + // 3. Protobuf DynamicShapeMessage. + message->body.append(ds_pb_msg_str); + // 4. The real data buffer need to be sent. + message->body.append(static_cast(data->addr), data->size); + } + return message; +} + +Receiver::~Receiver() { + if (server_) { + server_->Finalize(); + } + server_ = nullptr; + received_buffer_ = nullptr; +} + +std::unique_ptr> Receiver::Receive() { + std::unique_lock locker(received_msg_mtx_); + // The maximum time(300 seconds) to wait to receive message. + const int64_t longest_time_to_wait = 300; + received_msg_cv_.wait_for(locker, std::chrono::seconds(longest_time_to_wait), + [this] { return received_msg_.load(); }); + + std::unique_ptr> output = std::move(received_buffer_); + MS_EXCEPTION_IF_NULL(output); + received_msg_ = false; + return output; +} + +bool Receiver::StartServer() { + // 1. Create a tcp server and start listening. + server_ = std::make_unique(); + MS_EXCEPTION_IF_NULL(server_); + if (!server_->Initialize()) { + MS_LOG(EXCEPTION) << "Failed to initialize tcp server for recv actor"; + } + ip_ = server_->GetIP(); + port_ = server_->GetPort(); + std::string server_url = ip_ + ":" + std::to_string(port_); + + // 2. Set the message handler of the server. + server_->SetMessageHandler(std::bind(&Receiver::HandleMessage, this, std::placeholders::_1)); + + // 3. Register the server address to route table. The server should not be connected before this step is done. + MS_LOG(INFO) << "Start server for receiver. Server address: " << server_url + << ", inter process edge name: " << inter_process_edge_; + ActorAddress recv_actor_addresss; + recv_actor_addresss.set_actor_id(inter_process_edge_); + recv_actor_addresss.set_ip(ip_); + recv_actor_addresss.set_port(port_); + MS_EXCEPTION_IF_NULL(route_table_proxy_); + if (!route_table_proxy_->RegisterRoute(inter_process_edge_, recv_actor_addresss)) { + MS_LOG(EXCEPTION) << "Failed to register route for " << inter_process_edge_ << " " << server_url + << " when starting server."; + } + return true; +} + +bool Receiver::ParseDynamicShapeData(const char *msg_body, size_t msg_len, + std::pair *data) const { + MS_ERROR_IF_NULL(msg_body); + MS_ERROR_IF_NULL(data); + // 1. Check whether received data is valid dynamic shape data. + size_t dynamic_shape_header_size = strlen(kRpcDynamicShapeData); + if (msg_len <= dynamic_shape_header_size) { + MS_LOG(ERROR) << "Received data is not dynamic shape, data length: " << msg_len; + return false; + } + std::string msg_dynamic_shape_header(msg_body, dynamic_shape_header_size); + if (msg_dynamic_shape_header != kRpcDynamicShapeData) { + MS_LOG(ERROR) << "Received data is not dynamic shape, not find dynamic shape header: " << kRpcDynamicShapeData; + return false; + } + + size_t offset = dynamic_shape_header_size; + // 2. Parse the size of dynamic shape serialized protobuf message. + if (offset + sizeof(size_t) >= msg_len) { + MS_LOG(ERROR) << "Received data is incomplete"; + return false; + } + size_t dynamic_shape_pb_size = *(reinterpret_cast(msg_body + offset)); + offset += sizeof(size_t); + if (offset + dynamic_shape_pb_size >= msg_len) { + MS_LOG(ERROR) << "The dynamic shape pb data is incomplete"; + return false; + } + + // 3. Deserialize the dynamic shape serialized protobuf message. + rpc::DynamicShapeMessage pb_msg; + (void)pb_msg.ParseFromArray(msg_body + offset, dynamic_shape_pb_size); + offset += dynamic_shape_pb_size; + size_t received_data_len = msg_len - offset; + + // 4. The data integrity check. + ShapeVector shapes(pb_msg.shape_vector().begin(), pb_msg.shape_vector().end()); + TypeId data_type = static_cast(pb_msg.type_id()); + int64_t expected_data_len = 1; + std::vector size_t_shapes(shapes.begin(), shapes.end()); + if (!kernel::GetShapeSize(size_t_shapes, TypeIdToType(data_type), &expected_data_len)) { + MS_LOG(ERROR) << "Getting shape size for shape " << size_t_shapes << " failed."; + return false; + } + if (LongToSize(expected_data_len) != received_data_len) { + MS_LOG(ERROR) << "Received data is incomplete, expected size: " << expected_data_len + << ", but received data size: " << received_data_len; + return false; + } + // 5. Get real data addr and size. + *data = std::make_pair(msg_body + offset, received_data_len); + return true; +} + +MessageBase *Receiver::HandleMessage(MessageBase *const msg) { + if (msg == nullptr) { + MS_LOG(WARNING) << "Received message pointer is nullptr"; + return distributed::rpc::NULL_MSG; + } + + const std::string &msg_body = msg->body; + // The data pair: . + std::pair real_data; + // Get real data addr and size. + if (!ParseDynamicShapeData(msg_body.c_str(), msg_body.size(), &real_data)) { + MS_LOG(EXCEPTION) << "Parse dynamic shape data failed."; + } + + std::unique_lock locker(received_msg_mtx_); + received_buffer_ = std::make_unique>(); + received_buffer_->resize(real_data.second); + MS_EXCEPTION_IF_NULL(real_data.first); + + int ret = memcpy_s(received_buffer_->data(), received_buffer_->size(), real_data.first, real_data.second); + if (ret != 0) { + MS_LOG(EXCEPTION) << "Memcpy for received data failed, errno[" << ret << "]"; + } + + received_msg_ = true; + received_msg_cv_.notify_one(); + + delete msg; + return distributed::rpc::NULL_MSG; +} } // namespace runtime } // namespace mindspore diff --git a/mindspore/ccsrc/runtime/graph_scheduler/actor/embedding_cache/embedding_cache_prefetch_actor.h b/mindspore/ccsrc/runtime/graph_scheduler/actor/embedding_cache/embedding_cache_prefetch_actor.h index 57070ca6af8..9d7c6905906 100644 --- a/mindspore/ccsrc/runtime/graph_scheduler/actor/embedding_cache/embedding_cache_prefetch_actor.h +++ b/mindspore/ccsrc/runtime/graph_scheduler/actor/embedding_cache/embedding_cache_prefetch_actor.h @@ -24,10 +24,12 @@ #include #include "runtime/graph_scheduler/actor/actor_common.h" -#include "runtime/graph_scheduler/actor/rpc/send_actor.h" -#include "runtime/graph_scheduler/actor/rpc/recv_actor.h" #include "ir/anf.h" #include "backend/common/session/kernel_graph.h" +#include "distributed/cluster/cluster_context.h" +#include "distributed/rpc/tcp/tcp_client.h" +#include "distributed/rpc/tcp/tcp_server.h" +#include "utils/hash_map.h" // Note: After the code in ps/ps_cache are removed into runtime/addons/embedding_cache/, // the follow include file and using declaration of ps will be removed. @@ -46,6 +48,20 @@ using mindspore::ps::PsDataPrefetch; namespace mindspore { namespace runtime { +using kernel::Address; +using kernel::AddressPtr; +using kernel::AddressPtrList; + +class Sender; +class Receiver; +using SenderPtr = std::shared_ptr; +using ReceiverPtr = std::shared_ptr; +using SendRecvPair = std::pair; +using SendRecvPairList = std::vector; + +using distributed::cluster::ActorRouteTableProxy; +using distributed::cluster::ActorRouteTableProxyPtr; + // The EmbeddingCachePrefetchActor is used to cache large embedding table scenarios. The cache level is: Device // Cache->Local Host Cache->Remote Cache. This Actor is used to perform Local and Device Cache hit analysis and cache // prefetching (the feature weights corresponding to the ids of subsequent batches are assigned in advance Prefetching @@ -59,8 +75,8 @@ class EmbeddingCachePrefetchActor : public ActorBase { ~EmbeddingCachePrefetchActor() override = default; // Initialize embedding cache prefetch actor. - // 1. Build and Link rpc actors between local cache and remote cache. - // 2. Build network connection of rpc actors. + // 1. Build and Link rpc operators between local cache and remote cache. + // 2. Build network connection of rpc operators. void Initialize(); // Perform local cache hit analysis, prefetch the feature vector corresponding to the next batch into the cache. @@ -131,9 +147,10 @@ class EmbeddingCachePrefetchActor : public ActorBase { const int *indices_addr, float *output_addr); // Lookup embedding from Remote and get embeddings via RPC. - bool PullEembeddingsFromRemote(const int *ids, size_t ids_num, std::vector *outputs); + bool PullEembeddingsFromRemote(int32_t param_key, const int *ids, size_t ids_num, std::vector *outputs); // Push the local embedding cache that requires evict to the remote. - bool PushEmbeddingsToRemote(const int *ids, size_t ids_num, const float *embeddings, size_t embeddings_len); + bool PushEmbeddingsToRemote(int32_t param_key, const int *ids, size_t ids_num, const float *embeddings, + size_t embeddings_len); // Get the id range of each server's embedding table slice. void GetRemoteEmbeddingSliceBound(); @@ -149,20 +166,24 @@ class EmbeddingCachePrefetchActor : public ActorBase { std::vector> *slice_embeddings_list); // Send content to remote, such as ids or embeddings. - bool SendToRemote(size_t server_rank_id, const void *keys, size_t keys_len, const void *values = nullptr, - size_t values_len = 0); + // The parameter 'cache_operation' is cache operation name such as LookupEmbeddingCache and UpdateEmbeddingCache. + bool SendToRemote(const std::string &cache_operation, int32_t param_key, size_t server_rank_id, size_t embedding_dim, + const void *keys, size_t keys_len, const void *values = nullptr, size_t values_len = 0); // Wait response of remote and get return result. - bool WaitRespFromRemote(size_t server_rank_id, std::vector *outputs); + // The parameter 'cache_operation' is cache operation name such as LookupEmbeddingCache and UpdateEmbeddingCache. + std::unique_ptr> ReceiveFromRemote(const std::string &cache_operation, int32_t param_key, + size_t server_rank_id); // Retrieve embeddings by input ids order. bool RetrieveEmbeddings(const int *ids, size_t ids_num, const std::vector> &slice_ids_list, - const std::vector> &slice_embeddings_list, std::vector *outputs); + const std::vector>> &slice_embeddings_list, + std::vector *outputs); - // The cache prefetch phase may involve RPC communication with the server, implemented through Send Actor and - // Recv Actor. - // Build rpc actors. - void BuildRpcActors(); - // Link rpc actors by inter-process arrows. - void LinkRpcActors(); + // The cache prefetch phase may involve RPC communication with the server, implemented through Sender and + // Receiver. + // Build rpc operators. + void BuildRpcOperators(); + // Link rpc operators and build network connection. + void LinkRpcOperators(); // Build a CNode of embedding cache look up kernel(operator name: 'Gather'), which is used to look up local device // embedding cache. @@ -185,11 +206,10 @@ class EmbeddingCachePrefetchActor : public ActorBase { bool UpdateDeviceCache(void *indices, void *update_value, size_t indices_num, size_t cache_size, size_t embedding_size, void *embedding_cache); - // Record Send Actor and Recv Actor. - // Key: Inter process edge(Parameter name), Value: Send Actor. - std::map send_actors_; - // Key: Inter process edge(Parameter name), Value: Recv Actor. - std::map recv_actors_; + // Record sender and receiver pairs for different cache operation, server and parameter key. + // key: cache operation(such as LookupEmbeddingCache and UpdateEmbeddingCache) + // value: sender and receiver pairs for this kind of cache operation. + mindspore::HashMap> rpc_operators_; // The device interface. device::DeviceContext *device_context_; @@ -264,6 +284,106 @@ class EmbeddingCachePrefetchActor : public ActorBase { bool host_cache_need_wait_graph_{false}; }; +// RpcOperator is used to do rpc with other processes in distributed execution. +// RpcOperator use inter process edge to identify paired rpc operators uniquely. +class RpcOperator { + public: + RpcOperator() : inter_process_edge_(""), route_table_proxy_(nullptr) {} + ~RpcOperator() = default; + + // Set the inter-process edge name for rpc operators. + void set_inter_process_edge_name(const std::string &edge_name) { inter_process_edge_ = edge_name; } + + // Set the route table proxy for rpc operators. + void set_actor_route_table_proxy(const ActorRouteTableProxyPtr &route_table_proxy) { + route_table_proxy_ = route_table_proxy; + } + + protected: + // Unique edge name between rpc operator, format: + // src role + src rank id -> dst role + dst rank id + embedding cache operation + parameter key. + std::string inter_process_edge_; + + // Route table proxy for buildding network connection between nodes like workers and server. + ActorRouteTableProxyPtr route_table_proxy_; +}; + +// Sender is used to send data to other process. +class Sender : public RpcOperator { + public: + explicit Sender(const ReceiverPtr &receiver) : server_url_(""), client_(nullptr), receiver_(receiver) {} + ~Sender(); + + // Send buffer to peer. + bool Send(const std::vector &shapes, const std::vector data_types, + const AddressPtrList &data_list) const; + + // Lookup peer receiver's route and build network connection. + bool ConnectServer(); + + private: + // Build the MessageBase include dynamic shape protobuf, which will be sent to peer receiver. + // The message format is as below: + // |--------22 bytes-------|-------sizeof(size_t)-------|-dynamic shape PB data size-| real data size | + // |RPC_DYNAMIC_SHAPE_DATA | dynamic shape PB data size |---dynamic shape PB data----|---real data----| + // The message.from (from url) must be set. + std::unique_ptr BuildRpcMessage(const std::vector &shapes, + const std::vector data_types, const AddressPtrList &data_list, + const std::string &from_url, const std::string &to_url) const; + + // The url of the peer receiver's tcp server. + std::string server_url_; + + std::unique_ptr client_; + + // The sender and the receiver are used in pairs. The information sent by the sender contains the url of the + // corresponding receiver, so a reference to the receiver is maintained in the sender. + ReceiverPtr receiver_; +}; + +// Receiver is used to receive data from other process. +class Receiver : public RpcOperator { + public: + Receiver() : ip_(""), port_(0), server_(nullptr), received_buffer_(nullptr), received_msg_(false) {} + ~Receiver(); + + // Receive message from the peer sender, this interface is a synchronous interface and will wait for the message + // until the timeout period is reached. + std::unique_ptr> Receive(); + + // Start receiver server and register this server address to route table in scheduler by proxy. + bool StartServer(); + + // Get the url of this receiver, format: ip:port. + std::string get_url() const { return ip_ + ":" + std::to_string(port_); } + + private: + // The message callback of the tcp server. + MessageBase *HandleMessage(MessageBase *const msg); + + // Parse the dynamic shape protobuf message. The format is as below: + // |--------22 bytes-------|-------sizeof(size_t)-------|-dynamic shape PB data size-| real data size | + // |RPC_DYNAMIC_SHAPE_DATA | dynamic shape PB data size |---dynamic shape PB data----|---real data----| + // The output parameter 'data' contains real data addr and size. + bool ParseDynamicShapeData(const char *msg_body, size_t msg_len, std::pair *data) const; + + // The network address of this receiver. It's generated automatically by rpc module. + std::string ip_; + uint32_t port_; + + std::unique_ptr server_; + + // The buffer used save received content of message. + std::unique_ptr> received_buffer_; + + // The flag indicates whether receive message successfully. + std::atomic_bool received_msg_; + + // The interface 'Receive' is a synchronous, use condition variable to block thread and wait for the message. + std::condition_variable received_msg_cv_; + std::mutex received_msg_mtx_; +}; + using EmbeddingCachePrefetchActorPtr = std::shared_ptr; } // namespace runtime } // namespace mindspore