Add sender and receiver for embedding cache
This commit is contained in:
parent
66c856fc2c
commit
f1cf1139a3
|
@ -52,6 +52,7 @@ struct HashTableInfo {
|
|||
Address device_address{nullptr, 0};
|
||||
std::shared_ptr<float> host_address{nullptr};
|
||||
ParamInitInfo param_init_info_;
|
||||
int32_t param_key_{-1};
|
||||
};
|
||||
|
||||
struct EmbeddingDeviceCache {
|
||||
|
|
|
@ -17,11 +17,10 @@
|
|||
#include "runtime/graph_scheduler/actor/embedding_cache/embedding_cache_prefetch_actor.h"
|
||||
#include "backend/common/optimizer/dynamic_shape/dynamic_shape_helper.h"
|
||||
#include "kernel/common_utils.h"
|
||||
#include "runtime/graph_scheduler/actor/rpc/rpc_actor.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace runtime {
|
||||
using kernel::Address;
|
||||
using kernel::AddressPtrList;
|
||||
using mindspore::session::KernelGraph;
|
||||
|
||||
// One and two dimensional shape placeholder.
|
||||
|
@ -64,6 +63,66 @@ bool InferOpShape(const CNodePtr &kernel) {
|
|||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Generate unique inter process edge name, format:
|
||||
// src role + src rank id -> dst role + dst rank id + embedding cache operation + parameter key.
|
||||
std::string GenerateInterProcessEdge(const std::string &src_role, uint32_t src_rank, const std::string &dst_role,
|
||||
uint32_t dst_rank, const std::string &cache_operation, int32_t param_key) {
|
||||
std::string edge = src_role + std::to_string(src_rank) + "->" + dst_role + std::to_string(dst_rank) + "_" +
|
||||
cache_operation + "_" + distributed::kParameterKey + std::to_string(param_key);
|
||||
return edge;
|
||||
}
|
||||
|
||||
ActorRouteTableProxyPtr CreateRouteTableProxy() {
|
||||
auto node = ClusterContext::instance()->node();
|
||||
ActorRouteTableProxyPtr actor_route_table_proxy =
|
||||
std::make_shared<ActorRouteTableProxy>(std::dynamic_pointer_cast<ps::core::AbstractNode>(node));
|
||||
MS_EXCEPTION_IF_NULL(actor_route_table_proxy);
|
||||
return actor_route_table_proxy;
|
||||
}
|
||||
|
||||
// Create a sender and receiver pair,The sender and receiver are paired.
|
||||
// When creating a sender, need to create and specify the receiver paired with it in advance.
|
||||
SendRecvPair CreateSenderReceiverPair(uint32_t worker_rank, uint32_t server_rank, const std::string &cache_operation,
|
||||
int32_t param_key) {
|
||||
// Create sender and receiver pair.
|
||||
ReceiverPtr receiver = std::make_shared<Receiver>();
|
||||
SenderPtr sender = std::make_shared<Sender>(receiver);
|
||||
|
||||
// Set inter process edge
|
||||
receiver->set_inter_process_edge_name(GenerateInterProcessEdge(distributed::kEnvRoleOfServer, server_rank,
|
||||
distributed::kEnvRoleOfWorker, worker_rank,
|
||||
cache_operation, param_key));
|
||||
sender->set_inter_process_edge_name(GenerateInterProcessEdge(distributed::kEnvRoleOfWorker, worker_rank,
|
||||
distributed::kEnvRoleOfServer, server_rank,
|
||||
distributed::kLookupEmbeddingCache, param_key));
|
||||
|
||||
// Set route table proxy.
|
||||
receiver->set_actor_route_table_proxy(CreateRouteTableProxy());
|
||||
sender->set_actor_route_table_proxy(CreateRouteTableProxy());
|
||||
|
||||
return std::make_pair(sender, receiver);
|
||||
}
|
||||
|
||||
// Get cache operation service id which is used to decide which set of cache services to request.
|
||||
// The server side executes the corresponding service according to this id.
|
||||
int64_t GetCacheOpsServiceId(const std::string &cache_operation, int32_t param_key) {
|
||||
static mindspore::HashMap<std::string, int64_t> cache_ops_to_index;
|
||||
if (cache_ops_to_index.empty()) {
|
||||
int64_t cnt = 0;
|
||||
for (const auto &cache_op : distributed::kEmbeddingCacheOps) {
|
||||
cache_ops_to_index[cache_op] = cnt++;
|
||||
}
|
||||
}
|
||||
|
||||
auto iter = cache_ops_to_index.find(cache_operation);
|
||||
if (iter == cache_ops_to_index.end()) {
|
||||
MS_LOG(EXCEPTION) << "Can not find index for cache operation: " << cache_operation;
|
||||
}
|
||||
|
||||
int64_t id = SizeToLong(distributed::kEmbeddingCacheOps.size()) * IntToLong(param_key) + iter->second;
|
||||
return id;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void EmbeddingCachePrefetchActor::Initialize() {
|
||||
|
@ -76,6 +135,8 @@ void EmbeddingCachePrefetchActor::Initialize() {
|
|||
void EmbeddingCachePrefetchActor::Finalize() {
|
||||
embedding_cache_lookup_node_ = nullptr;
|
||||
embedding_cache_update_node_ = nullptr;
|
||||
|
||||
rpc_operators_.clear();
|
||||
}
|
||||
|
||||
void EmbeddingCachePrefetchActor::BuildEmbeddingCacheLookupKernel() {
|
||||
|
@ -607,8 +668,8 @@ bool EmbeddingCachePrefetchActor::PushCacheFromLocalHostToRemote(const HashTable
|
|||
RETURN_IF_FALSE_WITH_LOG(LookupLocalHostCache(embedding_size, swap_indices_size, host_hash_table_addr,
|
||||
host_to_server_index, swap_out_data.data()),
|
||||
"Lookup local host cache failed.");
|
||||
RETURN_IF_FALSE_WITH_LOG(PushEmbeddingsToRemote(host_to_server_ids, swap_indices_size, swap_out_data.data(),
|
||||
swap_out_data.size() * sizeof(float)),
|
||||
RETURN_IF_FALSE_WITH_LOG(PushEmbeddingsToRemote(hash_info.param_key_, host_to_server_ids, swap_indices_size,
|
||||
swap_out_data.data(), swap_out_data.size() * sizeof(float)),
|
||||
"Push embeddings to remote failed.");
|
||||
return true;
|
||||
}
|
||||
|
@ -663,8 +724,9 @@ bool EmbeddingCachePrefetchActor::PullCacheFromRemoteToLocalHost(const HashTable
|
|||
auto embedding_size = hash_info.embedding_size;
|
||||
std::vector<float> lookup_result(swap_indices_size * embedding_size, 0);
|
||||
|
||||
RETURN_IF_FALSE_WITH_LOG(PullEembeddingsFromRemote(server_to_host_ids, swap_indices_size, &lookup_result),
|
||||
"Pull embedding from remote failed.");
|
||||
RETURN_IF_FALSE_WITH_LOG(
|
||||
PullEembeddingsFromRemote(hash_info.param_key_, server_to_host_ids, swap_indices_size, &lookup_result),
|
||||
"Pull embedding from remote failed.");
|
||||
RETURN_IF_FALSE_WITH_LOG(InsertLocalHostCache(embedding_size, IntToSize(swap_indices_size), server_to_host_index,
|
||||
lookup_result.data(), host_hash_table_addr),
|
||||
"Insert local host cache failed.");
|
||||
|
@ -823,15 +885,21 @@ bool EmbeddingCachePrefetchActor::LookupLocalHostCache(size_t embedding_size, si
|
|||
return running_;
|
||||
}
|
||||
|
||||
bool EmbeddingCachePrefetchActor::PullEembeddingsFromRemote(const int *ids, size_t ids_num,
|
||||
bool EmbeddingCachePrefetchActor::PullEembeddingsFromRemote(int32_t param_key, const int *ids, size_t ids_num,
|
||||
std::vector<float> *outputs) {
|
||||
MS_ERROR_IF_NULL(ids);
|
||||
MS_ERROR_IF_NULL(outputs);
|
||||
|
||||
if (ids_num == 0) {
|
||||
MS_LOG(WARNING) << "The ids number is 0";
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<std::vector<int>> slice_ids_list(server_num_);
|
||||
// 1. Partition ids by remote embedding slice bound and get unique ids.
|
||||
RETURN_IF_FALSE_WITH_LOG(PartitionIds(ids, ids_num, &slice_ids_list), "Partition ids failed.");
|
||||
|
||||
size_t embedding_dim = outputs->size() / ids_num;
|
||||
for (size_t i = 0; i < server_num_; i++) {
|
||||
auto &slice_ids = slice_ids_list[i];
|
||||
if (slice_ids.empty()) {
|
||||
|
@ -839,19 +907,28 @@ bool EmbeddingCachePrefetchActor::PullEembeddingsFromRemote(const int *ids, size
|
|||
}
|
||||
|
||||
// 2. Send unique ids to remote to do embedding lookup.
|
||||
RETURN_IF_FALSE_WITH_LOG(SendToRemote(i, slice_ids.data(), slice_ids.size() * sizeof(int)),
|
||||
RETURN_IF_FALSE_WITH_LOG(SendToRemote(distributed::kLookupEmbeddingCache, param_key, i, embedding_dim,
|
||||
slice_ids.data(), slice_ids.size() * sizeof(int)),
|
||||
"Send ids to server failed.");
|
||||
}
|
||||
|
||||
std::vector<std::vector<float>> slice_embeddings_list(server_num_);
|
||||
std::vector<std::unique_ptr<std::vector<char>>> slice_embeddings_list(server_num_);
|
||||
for (size_t i = 0; i < server_num_; i++) {
|
||||
if (slice_ids_list[i].empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// 3. Wait embeddings result.
|
||||
auto &slice_embeddings = slice_embeddings_list[i];
|
||||
RETURN_IF_FALSE_WITH_LOG(WaitRespFromRemote(i, &slice_embeddings), "Wait response from server failed.");
|
||||
slice_embeddings_list[i] = ReceiveFromRemote(distributed::kLookupEmbeddingCache, param_key, i);
|
||||
MS_ERROR_IF_NULL(slice_embeddings_list[i]);
|
||||
// Received embedding integrity check.
|
||||
size_t expected_embedding_size = SizetMulWithOverflowCheck(slice_ids_list[i].size(), embedding_dim);
|
||||
size_t received_embedding_size = slice_embeddings_list[i]->size() / sizeof(float);
|
||||
if (received_embedding_size != expected_embedding_size) {
|
||||
MS_LOG(ERROR) << "Received embedding data from remote is incomplete, expected embedding size: "
|
||||
<< expected_embedding_size << ", but received embedding size: " << received_embedding_size;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// 4. Retrieve embeddings by input ids order.
|
||||
|
@ -861,11 +938,16 @@ bool EmbeddingCachePrefetchActor::PullEembeddingsFromRemote(const int *ids, size
|
|||
return true;
|
||||
}
|
||||
|
||||
bool EmbeddingCachePrefetchActor::PushEmbeddingsToRemote(const int *ids, size_t ids_num, const float *embeddings,
|
||||
size_t embeddings_len) {
|
||||
bool EmbeddingCachePrefetchActor::PushEmbeddingsToRemote(int32_t param_key, const int *ids, size_t ids_num,
|
||||
const float *embeddings, size_t embeddings_len) {
|
||||
MS_ERROR_IF_NULL(ids);
|
||||
MS_ERROR_IF_NULL(embeddings);
|
||||
|
||||
if (ids_num == 0) {
|
||||
MS_LOG(WARNING) << "The ids number is 0";
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<std::vector<int>> slice_ids_list(server_num_);
|
||||
std::vector<std::vector<float>> slice_embeddings_list(server_num_);
|
||||
// 1. Partition ids end embeddings by remote embedding slice bound.
|
||||
|
@ -873,6 +955,7 @@ bool EmbeddingCachePrefetchActor::PushEmbeddingsToRemote(const int *ids, size_t
|
|||
PartitionIdsAndEmbeddings(ids, ids_num, embeddings, embeddings_len, &slice_ids_list, &slice_embeddings_list),
|
||||
"Partition ids and embeddings failed.");
|
||||
|
||||
size_t embedding_dim = (embeddings_len / ids_num) / sizeof(float);
|
||||
for (size_t i = 0; i < server_num_; i++) {
|
||||
auto &slice_ids = slice_ids_list[i];
|
||||
if (slice_ids.empty()) {
|
||||
|
@ -881,9 +964,10 @@ bool EmbeddingCachePrefetchActor::PushEmbeddingsToRemote(const int *ids, size_t
|
|||
|
||||
// 2. Send embeddings to remote.
|
||||
auto &slice_embeddings = slice_embeddings_list[i];
|
||||
RETURN_IF_FALSE_WITH_LOG(SendToRemote(i, slice_ids.data(), slice_ids.size() * sizeof(int), slice_embeddings.data(),
|
||||
slice_embeddings.size() * sizeof(float)),
|
||||
"Send ids and embeddings to server failed.");
|
||||
RETURN_IF_FALSE_WITH_LOG(
|
||||
SendToRemote(distributed::kUpdateEmbeddingCache, param_key, i, embedding_dim, slice_ids.data(),
|
||||
slice_ids.size() * sizeof(int), slice_embeddings.data(), slice_embeddings.size() * sizeof(float)),
|
||||
"Send ids and embeddings to server failed.");
|
||||
}
|
||||
|
||||
return true;
|
||||
|
@ -971,21 +1055,51 @@ bool EmbeddingCachePrefetchActor::PartitionIdsAndEmbeddings(const int *ids, size
|
|||
return true;
|
||||
}
|
||||
|
||||
bool EmbeddingCachePrefetchActor::SendToRemote(size_t server_rank_id, const void *keys, size_t keys_len,
|
||||
const void *values, size_t values_len) {
|
||||
// Note: Need to implement the method via send actor.
|
||||
return true;
|
||||
bool EmbeddingCachePrefetchActor::SendToRemote(const std::string &cache_operation, int32_t param_key,
|
||||
size_t server_rank_id, size_t embedding_dim, const void *keys,
|
||||
size_t keys_len, const void *values, size_t values_len) {
|
||||
// Find sender corresponding to cache operation and parameter key.
|
||||
auto iter = rpc_operators_.find(cache_operation);
|
||||
if (iter == rpc_operators_.end()) {
|
||||
MS_LOG(ERROR) << "Can not find rpc operator for cache operation: " << cache_operation;
|
||||
}
|
||||
|
||||
const std::vector<SendRecvPairList> &send_recv_pair_lists = iter->second;
|
||||
const SenderPtr &sender = send_recv_pair_lists[server_rank_id][param_key].first;
|
||||
MS_ERROR_IF_NULL(sender);
|
||||
|
||||
int64_t ids_num = SizeToLong(keys_len / sizeof(int));
|
||||
std::vector<ShapeVector> shapes = {{ids_num}, {ids_num, SizeToLong(embedding_dim)}, {1}};
|
||||
std::vector<TypeId> data_types = {kNumberTypeInt32, kNumberTypeFloat32, kNumberTypeInt64};
|
||||
|
||||
int64_t service_id = GetCacheOpsServiceId(cache_operation, param_key);
|
||||
AddressPtrList data_list = {std::make_shared<Address>(const_cast<void *>(keys), keys_len),
|
||||
std::make_shared<Address>(const_cast<void *>(values), values_len),
|
||||
std::make_shared<Address>(&service_id, sizeof(int64_t))};
|
||||
|
||||
// Send data.
|
||||
return sender->Send(shapes, data_types, data_list);
|
||||
}
|
||||
|
||||
bool EmbeddingCachePrefetchActor::WaitRespFromRemote(size_t server_rank_id, std::vector<float> *outputs) {
|
||||
// Note: Need to implement the method via recv actor.
|
||||
return true;
|
||||
std::unique_ptr<std::vector<char>> EmbeddingCachePrefetchActor::ReceiveFromRemote(const std::string &cache_operation,
|
||||
int32_t param_key,
|
||||
size_t server_rank_id) {
|
||||
// Find receiver corresponding to cache operation and parameter key.
|
||||
auto iter = rpc_operators_.find(cache_operation);
|
||||
if (iter == rpc_operators_.end()) {
|
||||
MS_LOG(ERROR) << "Can not find rpc operator for cache operation: " << cache_operation;
|
||||
}
|
||||
|
||||
const std::vector<SendRecvPairList> &send_recv_pair_lists = iter->second;
|
||||
const ReceiverPtr &receiver = send_recv_pair_lists[server_rank_id][param_key].second;
|
||||
MS_EXCEPTION_IF_NULL(receiver);
|
||||
// Receive data.
|
||||
return receiver->Receive();
|
||||
}
|
||||
|
||||
bool EmbeddingCachePrefetchActor::RetrieveEmbeddings(const int *ids, size_t ids_num,
|
||||
const std::vector<std::vector<int>> &slice_ids_list,
|
||||
const std::vector<std::vector<float>> &slice_embeddings_list,
|
||||
std::vector<float> *outputs) {
|
||||
bool EmbeddingCachePrefetchActor::RetrieveEmbeddings(
|
||||
const int *ids, size_t ids_num, const std::vector<std::vector<int>> &slice_ids_list,
|
||||
const std::vector<std::unique_ptr<std::vector<char>>> &slice_embeddings_list, std::vector<float> *outputs) {
|
||||
MS_ERROR_IF_NULL(ids);
|
||||
MS_ERROR_IF_NULL(outputs);
|
||||
|
||||
|
@ -1003,8 +1117,9 @@ bool EmbeddingCachePrefetchActor::RetrieveEmbeddings(const int *ids, size_t ids_
|
|||
if (slice_ids.empty()) {
|
||||
continue;
|
||||
}
|
||||
const std::vector<float> &slice_embeddings = slice_embeddings_list[i];
|
||||
const float *embeddings_data = slice_embeddings.data();
|
||||
const std::unique_ptr<std::vector<char>> &slice_embeddings = slice_embeddings_list[i];
|
||||
MS_ERROR_IF_NULL(slice_embeddings);
|
||||
const float *embeddings_data = reinterpret_cast<float *>(slice_embeddings->data());
|
||||
for (size_t j = 0; j < slice_ids.size(); j++) {
|
||||
(void)ids_to_addrs.emplace(slice_ids[j], embeddings_data + offset);
|
||||
offset += embedding_dim;
|
||||
|
@ -1035,5 +1150,287 @@ bool EmbeddingCachePrefetchActor::RetrieveEmbeddings(const int *ids, size_t ids_
|
|||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void EmbeddingCachePrefetchActor::BuildRpcOperators() {
|
||||
// The cache operation support LookupEmbeddingCache and UpdateEmbeddingCache currently.
|
||||
for (const auto &cache_op : distributed::kEmbeddingCacheOps) {
|
||||
rpc_operators_[cache_op] = std::vector<SendRecvPairList>();
|
||||
rpc_operators_[cache_op].resize(server_num_);
|
||||
}
|
||||
|
||||
auto node = distributed::cluster::ClusterContext::instance()->node();
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
uint32_t worker_rank_id = node->rank_id();
|
||||
|
||||
// Create sender and receiver pairs for different cache operation, server and parameter key.
|
||||
for (auto &item : rpc_operators_) {
|
||||
const std::string &cache_op = item.first;
|
||||
std::vector<SendRecvPairList> &send_recv_pair_lists = item.second;
|
||||
for (uint32_t i = 0; i < server_num_; i++) {
|
||||
SendRecvPairList &send_recv_pair_list = send_recv_pair_lists[i];
|
||||
send_recv_pair_list.resize(hash_tables_.size());
|
||||
|
||||
for (const auto &table : hash_tables_) {
|
||||
int32_t key = table.second.param_key_;
|
||||
if (key >= SizeToInt(hash_tables_.size()) || key < 0) {
|
||||
MS_LOG(EXCEPTION) << "Invalid parameter key: " << key;
|
||||
}
|
||||
|
||||
send_recv_pair_list[key] = CreateSenderReceiverPair(worker_rank_id, i, cache_op, key);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void EmbeddingCachePrefetchActor::LinkRpcOperators() {
|
||||
std::vector<SenderPtr> senders;
|
||||
std::vector<ReceiverPtr> receivers;
|
||||
for (const auto &item : rpc_operators_) {
|
||||
const std::vector<SendRecvPairList> &send_recv_pair_lists = item.second;
|
||||
for (const SendRecvPairList &send_recv_pair_list : send_recv_pair_lists) {
|
||||
for (const SendRecvPair &pair : send_recv_pair_list) {
|
||||
senders.push_back(pair.first);
|
||||
receivers.push_back(pair.second);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Must start server and register route table before looking up route and connecting.
|
||||
// Start servers of receiver and register route table.
|
||||
for (auto &receiver : receivers) {
|
||||
MS_EXCEPTION_IF_NULL(receiver);
|
||||
if (!receiver->StartServer()) {
|
||||
MS_LOG(EXCEPTION) << "Failed to start server for the receiver.";
|
||||
}
|
||||
}
|
||||
|
||||
// Lookup route and connect to servers for sender.
|
||||
for (auto &sender : senders) {
|
||||
MS_EXCEPTION_IF_NULL(sender);
|
||||
if (!sender->ConnectServer()) {
|
||||
MS_LOG(EXCEPTION) << "Failed to connect servers for the sender.";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool Sender::Send(const std::vector<ShapeVector> &shapes, const std::vector<TypeId> data_types,
|
||||
const AddressPtrList &data_list) const {
|
||||
MS_ERROR_IF_NULL(receiver_);
|
||||
auto message = BuildRpcMessage(shapes, data_types, data_list, receiver_->get_url(), server_url_);
|
||||
MS_ERROR_IF_NULL(message);
|
||||
MS_ERROR_IF_NULL(client_);
|
||||
client_->SendAsync(std::move(message));
|
||||
return true;
|
||||
}
|
||||
|
||||
Sender::~Sender() {
|
||||
if (client_) {
|
||||
client_->Disconnect(server_url_);
|
||||
client_->Finalize();
|
||||
}
|
||||
client_ = nullptr;
|
||||
receiver_ = nullptr;
|
||||
}
|
||||
|
||||
bool Sender::ConnectServer() {
|
||||
client_ = std::make_unique<TCPClient>();
|
||||
MS_ERROR_IF_NULL(client_);
|
||||
if (!client_->Initialize()) {
|
||||
MS_LOG(ERROR) << "Failed to initialize tcp server for send actor.";
|
||||
return false;
|
||||
}
|
||||
|
||||
// Lookup peer receiver addresses.
|
||||
MS_ERROR_IF_NULL(route_table_proxy_);
|
||||
auto peer_actor_address = route_table_proxy_->LookupRoute(inter_process_edge_);
|
||||
server_url_ = peer_actor_address.ip() + ":" + std::to_string(peer_actor_address.port());
|
||||
if (!client_->Connect(server_url_)) {
|
||||
MS_LOG(ERROR) << "Failed to connect to server of edge: " << inter_process_edge_ << ", server_url: " << server_url_;
|
||||
return false;
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "Successfully connect to server " << server_url_
|
||||
<< ", inter process edge name: " << inter_process_edge_;
|
||||
return true;
|
||||
}
|
||||
|
||||
std::unique_ptr<MessageBase> Sender::BuildRpcMessage(const std::vector<ShapeVector> &shapes,
|
||||
const std::vector<TypeId> data_types,
|
||||
const AddressPtrList &data_list, const std::string &from_url,
|
||||
const std::string &to_url) const {
|
||||
std::unique_ptr<MessageBase> message = std::make_unique<MessageBase>();
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(message, nullptr);
|
||||
message->from = AID("", from_url);
|
||||
message->to = AID("", to_url);
|
||||
|
||||
if (shapes.size() != data_list.size()) {
|
||||
MS_LOG(ERROR) << "The shape list size[" << shapes.size() << "] should be equal to data list size["
|
||||
<< data_list.size() << "]";
|
||||
}
|
||||
|
||||
if (data_types.size() != data_list.size()) {
|
||||
MS_LOG(ERROR) << "The date type list size[" << data_types.size() << "] should be equal to data list size["
|
||||
<< data_list.size() << "]";
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < data_list.size(); i++) {
|
||||
const ShapeVector &shape = shapes[i];
|
||||
const AddressPtr &data = data_list[i];
|
||||
const TypeId &type_id = data_types[i];
|
||||
|
||||
rpc::DynamicShapeMessage ds_pb_msg;
|
||||
ds_pb_msg.set_type_id(type_id);
|
||||
*ds_pb_msg.mutable_shape_vector() = {shape.begin(), shape.end()};
|
||||
std::string ds_pb_msg_str = ds_pb_msg.SerializeAsString();
|
||||
|
||||
// Message format:
|
||||
// |RPC_DYNAMIC_SHAPE_DATA | dynamic shape PB data size |---dynamic shape PB data----|---real data----|
|
||||
// 1. The dynamic shape header.
|
||||
message->body.append(kRpcDynamicShapeData);
|
||||
// 2. The size of the protobuf DynamicShapeMessage.
|
||||
size_t ds_pb_msg_size = ds_pb_msg_str.size();
|
||||
message->body.append(reinterpret_cast<char *>(&ds_pb_msg_size), sizeof(ds_pb_msg_size));
|
||||
// 3. Protobuf DynamicShapeMessage.
|
||||
message->body.append(ds_pb_msg_str);
|
||||
// 4. The real data buffer need to be sent.
|
||||
message->body.append(static_cast<char *>(data->addr), data->size);
|
||||
}
|
||||
return message;
|
||||
}
|
||||
|
||||
Receiver::~Receiver() {
|
||||
if (server_) {
|
||||
server_->Finalize();
|
||||
}
|
||||
server_ = nullptr;
|
||||
received_buffer_ = nullptr;
|
||||
}
|
||||
|
||||
std::unique_ptr<std::vector<char>> Receiver::Receive() {
|
||||
std::unique_lock<std::mutex> locker(received_msg_mtx_);
|
||||
// The maximum time(300 seconds) to wait to receive message.
|
||||
const int64_t longest_time_to_wait = 300;
|
||||
received_msg_cv_.wait_for(locker, std::chrono::seconds(longest_time_to_wait),
|
||||
[this] { return received_msg_.load(); });
|
||||
|
||||
std::unique_ptr<std::vector<char>> output = std::move(received_buffer_);
|
||||
MS_EXCEPTION_IF_NULL(output);
|
||||
received_msg_ = false;
|
||||
return output;
|
||||
}
|
||||
|
||||
bool Receiver::StartServer() {
|
||||
// 1. Create a tcp server and start listening.
|
||||
server_ = std::make_unique<TCPServer>();
|
||||
MS_EXCEPTION_IF_NULL(server_);
|
||||
if (!server_->Initialize()) {
|
||||
MS_LOG(EXCEPTION) << "Failed to initialize tcp server for recv actor";
|
||||
}
|
||||
ip_ = server_->GetIP();
|
||||
port_ = server_->GetPort();
|
||||
std::string server_url = ip_ + ":" + std::to_string(port_);
|
||||
|
||||
// 2. Set the message handler of the server.
|
||||
server_->SetMessageHandler(std::bind(&Receiver::HandleMessage, this, std::placeholders::_1));
|
||||
|
||||
// 3. Register the server address to route table. The server should not be connected before this step is done.
|
||||
MS_LOG(INFO) << "Start server for receiver. Server address: " << server_url
|
||||
<< ", inter process edge name: " << inter_process_edge_;
|
||||
ActorAddress recv_actor_addresss;
|
||||
recv_actor_addresss.set_actor_id(inter_process_edge_);
|
||||
recv_actor_addresss.set_ip(ip_);
|
||||
recv_actor_addresss.set_port(port_);
|
||||
MS_EXCEPTION_IF_NULL(route_table_proxy_);
|
||||
if (!route_table_proxy_->RegisterRoute(inter_process_edge_, recv_actor_addresss)) {
|
||||
MS_LOG(EXCEPTION) << "Failed to register route for " << inter_process_edge_ << " " << server_url
|
||||
<< " when starting server.";
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Receiver::ParseDynamicShapeData(const char *msg_body, size_t msg_len,
|
||||
std::pair<const void *, size_t> *data) const {
|
||||
MS_ERROR_IF_NULL(msg_body);
|
||||
MS_ERROR_IF_NULL(data);
|
||||
// 1. Check whether received data is valid dynamic shape data.
|
||||
size_t dynamic_shape_header_size = strlen(kRpcDynamicShapeData);
|
||||
if (msg_len <= dynamic_shape_header_size) {
|
||||
MS_LOG(ERROR) << "Received data is not dynamic shape, data length: " << msg_len;
|
||||
return false;
|
||||
}
|
||||
std::string msg_dynamic_shape_header(msg_body, dynamic_shape_header_size);
|
||||
if (msg_dynamic_shape_header != kRpcDynamicShapeData) {
|
||||
MS_LOG(ERROR) << "Received data is not dynamic shape, not find dynamic shape header: " << kRpcDynamicShapeData;
|
||||
return false;
|
||||
}
|
||||
|
||||
size_t offset = dynamic_shape_header_size;
|
||||
// 2. Parse the size of dynamic shape serialized protobuf message.
|
||||
if (offset + sizeof(size_t) >= msg_len) {
|
||||
MS_LOG(ERROR) << "Received data is incomplete";
|
||||
return false;
|
||||
}
|
||||
size_t dynamic_shape_pb_size = *(reinterpret_cast<const size_t *>(msg_body + offset));
|
||||
offset += sizeof(size_t);
|
||||
if (offset + dynamic_shape_pb_size >= msg_len) {
|
||||
MS_LOG(ERROR) << "The dynamic shape pb data is incomplete";
|
||||
return false;
|
||||
}
|
||||
|
||||
// 3. Deserialize the dynamic shape serialized protobuf message.
|
||||
rpc::DynamicShapeMessage pb_msg;
|
||||
(void)pb_msg.ParseFromArray(msg_body + offset, dynamic_shape_pb_size);
|
||||
offset += dynamic_shape_pb_size;
|
||||
size_t received_data_len = msg_len - offset;
|
||||
|
||||
// 4. The data integrity check.
|
||||
ShapeVector shapes(pb_msg.shape_vector().begin(), pb_msg.shape_vector().end());
|
||||
TypeId data_type = static_cast<TypeId>(pb_msg.type_id());
|
||||
int64_t expected_data_len = 1;
|
||||
std::vector<size_t> size_t_shapes(shapes.begin(), shapes.end());
|
||||
if (!kernel::GetShapeSize(size_t_shapes, TypeIdToType(data_type), &expected_data_len)) {
|
||||
MS_LOG(ERROR) << "Getting shape size for shape " << size_t_shapes << " failed.";
|
||||
return false;
|
||||
}
|
||||
if (LongToSize(expected_data_len) != received_data_len) {
|
||||
MS_LOG(ERROR) << "Received data is incomplete, expected size: " << expected_data_len
|
||||
<< ", but received data size: " << received_data_len;
|
||||
return false;
|
||||
}
|
||||
// 5. Get real data addr and size.
|
||||
*data = std::make_pair(msg_body + offset, received_data_len);
|
||||
return true;
|
||||
}
|
||||
|
||||
MessageBase *Receiver::HandleMessage(MessageBase *const msg) {
|
||||
if (msg == nullptr) {
|
||||
MS_LOG(WARNING) << "Received message pointer is nullptr";
|
||||
return distributed::rpc::NULL_MSG;
|
||||
}
|
||||
|
||||
const std::string &msg_body = msg->body;
|
||||
// The data pair: <addr of data, size of data>.
|
||||
std::pair<const void *, size_t> real_data;
|
||||
// Get real data addr and size.
|
||||
if (!ParseDynamicShapeData(msg_body.c_str(), msg_body.size(), &real_data)) {
|
||||
MS_LOG(EXCEPTION) << "Parse dynamic shape data failed.";
|
||||
}
|
||||
|
||||
std::unique_lock<std::mutex> locker(received_msg_mtx_);
|
||||
received_buffer_ = std::make_unique<std::vector<char>>();
|
||||
received_buffer_->resize(real_data.second);
|
||||
MS_EXCEPTION_IF_NULL(real_data.first);
|
||||
|
||||
int ret = memcpy_s(received_buffer_->data(), received_buffer_->size(), real_data.first, real_data.second);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "Memcpy for received data failed, errno[" << ret << "]";
|
||||
}
|
||||
|
||||
received_msg_ = true;
|
||||
received_msg_cv_.notify_one();
|
||||
|
||||
delete msg;
|
||||
return distributed::rpc::NULL_MSG;
|
||||
}
|
||||
} // namespace runtime
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -24,10 +24,12 @@
|
|||
#include <utility>
|
||||
|
||||
#include "runtime/graph_scheduler/actor/actor_common.h"
|
||||
#include "runtime/graph_scheduler/actor/rpc/send_actor.h"
|
||||
#include "runtime/graph_scheduler/actor/rpc/recv_actor.h"
|
||||
#include "ir/anf.h"
|
||||
#include "backend/common/session/kernel_graph.h"
|
||||
#include "distributed/cluster/cluster_context.h"
|
||||
#include "distributed/rpc/tcp/tcp_client.h"
|
||||
#include "distributed/rpc/tcp/tcp_server.h"
|
||||
#include "utils/hash_map.h"
|
||||
|
||||
// Note: After the code in ps/ps_cache are removed into runtime/addons/embedding_cache/,
|
||||
// the follow include file and using declaration of ps will be removed.
|
||||
|
@ -46,6 +48,20 @@ using mindspore::ps::PsDataPrefetch;
|
|||
|
||||
namespace mindspore {
|
||||
namespace runtime {
|
||||
using kernel::Address;
|
||||
using kernel::AddressPtr;
|
||||
using kernel::AddressPtrList;
|
||||
|
||||
class Sender;
|
||||
class Receiver;
|
||||
using SenderPtr = std::shared_ptr<Sender>;
|
||||
using ReceiverPtr = std::shared_ptr<Receiver>;
|
||||
using SendRecvPair = std::pair<SenderPtr, ReceiverPtr>;
|
||||
using SendRecvPairList = std::vector<SendRecvPair>;
|
||||
|
||||
using distributed::cluster::ActorRouteTableProxy;
|
||||
using distributed::cluster::ActorRouteTableProxyPtr;
|
||||
|
||||
// The EmbeddingCachePrefetchActor is used to cache large embedding table scenarios. The cache level is: Device
|
||||
// Cache->Local Host Cache->Remote Cache. This Actor is used to perform Local and Device Cache hit analysis and cache
|
||||
// prefetching (the feature weights corresponding to the ids of subsequent batches are assigned in advance Prefetching
|
||||
|
@ -59,8 +75,8 @@ class EmbeddingCachePrefetchActor : public ActorBase {
|
|||
~EmbeddingCachePrefetchActor() override = default;
|
||||
|
||||
// Initialize embedding cache prefetch actor.
|
||||
// 1. Build and Link rpc actors between local cache and remote cache.
|
||||
// 2. Build network connection of rpc actors.
|
||||
// 1. Build and Link rpc operators between local cache and remote cache.
|
||||
// 2. Build network connection of rpc operators.
|
||||
void Initialize();
|
||||
|
||||
// Perform local cache hit analysis, prefetch the feature vector corresponding to the next batch into the cache.
|
||||
|
@ -131,9 +147,10 @@ class EmbeddingCachePrefetchActor : public ActorBase {
|
|||
const int *indices_addr, float *output_addr);
|
||||
|
||||
// Lookup embedding from Remote and get embeddings via RPC.
|
||||
bool PullEembeddingsFromRemote(const int *ids, size_t ids_num, std::vector<float> *outputs);
|
||||
bool PullEembeddingsFromRemote(int32_t param_key, const int *ids, size_t ids_num, std::vector<float> *outputs);
|
||||
// Push the local embedding cache that requires evict to the remote.
|
||||
bool PushEmbeddingsToRemote(const int *ids, size_t ids_num, const float *embeddings, size_t embeddings_len);
|
||||
bool PushEmbeddingsToRemote(int32_t param_key, const int *ids, size_t ids_num, const float *embeddings,
|
||||
size_t embeddings_len);
|
||||
|
||||
// Get the id range of each server's embedding table slice.
|
||||
void GetRemoteEmbeddingSliceBound();
|
||||
|
@ -149,20 +166,24 @@ class EmbeddingCachePrefetchActor : public ActorBase {
|
|||
std::vector<std::vector<float>> *slice_embeddings_list);
|
||||
|
||||
// Send content to remote, such as ids or embeddings.
|
||||
bool SendToRemote(size_t server_rank_id, const void *keys, size_t keys_len, const void *values = nullptr,
|
||||
size_t values_len = 0);
|
||||
// The parameter 'cache_operation' is cache operation name such as LookupEmbeddingCache and UpdateEmbeddingCache.
|
||||
bool SendToRemote(const std::string &cache_operation, int32_t param_key, size_t server_rank_id, size_t embedding_dim,
|
||||
const void *keys, size_t keys_len, const void *values = nullptr, size_t values_len = 0);
|
||||
// Wait response of remote and get return result.
|
||||
bool WaitRespFromRemote(size_t server_rank_id, std::vector<float> *outputs);
|
||||
// The parameter 'cache_operation' is cache operation name such as LookupEmbeddingCache and UpdateEmbeddingCache.
|
||||
std::unique_ptr<std::vector<char>> ReceiveFromRemote(const std::string &cache_operation, int32_t param_key,
|
||||
size_t server_rank_id);
|
||||
// Retrieve embeddings by input ids order.
|
||||
bool RetrieveEmbeddings(const int *ids, size_t ids_num, const std::vector<std::vector<int>> &slice_ids_list,
|
||||
const std::vector<std::vector<float>> &slice_embeddings_list, std::vector<float> *outputs);
|
||||
const std::vector<std::unique_ptr<std::vector<char>>> &slice_embeddings_list,
|
||||
std::vector<float> *outputs);
|
||||
|
||||
// The cache prefetch phase may involve RPC communication with the server, implemented through Send Actor and
|
||||
// Recv Actor.
|
||||
// Build rpc actors.
|
||||
void BuildRpcActors();
|
||||
// Link rpc actors by inter-process arrows.
|
||||
void LinkRpcActors();
|
||||
// The cache prefetch phase may involve RPC communication with the server, implemented through Sender and
|
||||
// Receiver.
|
||||
// Build rpc operators.
|
||||
void BuildRpcOperators();
|
||||
// Link rpc operators and build network connection.
|
||||
void LinkRpcOperators();
|
||||
|
||||
// Build a CNode of embedding cache look up kernel(operator name: 'Gather'), which is used to look up local device
|
||||
// embedding cache.
|
||||
|
@ -185,11 +206,10 @@ class EmbeddingCachePrefetchActor : public ActorBase {
|
|||
bool UpdateDeviceCache(void *indices, void *update_value, size_t indices_num, size_t cache_size,
|
||||
size_t embedding_size, void *embedding_cache);
|
||||
|
||||
// Record Send Actor and Recv Actor.
|
||||
// Key: Inter process edge(Parameter name), Value: Send Actor.
|
||||
std::map<std::string, SendActorPtr> send_actors_;
|
||||
// Key: Inter process edge(Parameter name), Value: Recv Actor.
|
||||
std::map<std::string, RecvActorPtr> recv_actors_;
|
||||
// Record sender and receiver pairs for different cache operation, server and parameter key.
|
||||
// key: cache operation(such as LookupEmbeddingCache and UpdateEmbeddingCache)
|
||||
// value: sender and receiver pairs for this kind of cache operation.
|
||||
mindspore::HashMap<std::string, std::vector<SendRecvPairList>> rpc_operators_;
|
||||
|
||||
// The device interface.
|
||||
device::DeviceContext *device_context_;
|
||||
|
@ -264,6 +284,106 @@ class EmbeddingCachePrefetchActor : public ActorBase {
|
|||
bool host_cache_need_wait_graph_{false};
|
||||
};
|
||||
|
||||
// RpcOperator is used to do rpc with other processes in distributed execution.
|
||||
// RpcOperator use inter process edge to identify paired rpc operators uniquely.
|
||||
class RpcOperator {
|
||||
public:
|
||||
RpcOperator() : inter_process_edge_(""), route_table_proxy_(nullptr) {}
|
||||
~RpcOperator() = default;
|
||||
|
||||
// Set the inter-process edge name for rpc operators.
|
||||
void set_inter_process_edge_name(const std::string &edge_name) { inter_process_edge_ = edge_name; }
|
||||
|
||||
// Set the route table proxy for rpc operators.
|
||||
void set_actor_route_table_proxy(const ActorRouteTableProxyPtr &route_table_proxy) {
|
||||
route_table_proxy_ = route_table_proxy;
|
||||
}
|
||||
|
||||
protected:
|
||||
// Unique edge name between rpc operator, format:
|
||||
// src role + src rank id -> dst role + dst rank id + embedding cache operation + parameter key.
|
||||
std::string inter_process_edge_;
|
||||
|
||||
// Route table proxy for buildding network connection between nodes like workers and server.
|
||||
ActorRouteTableProxyPtr route_table_proxy_;
|
||||
};
|
||||
|
||||
// Sender is used to send data to other process.
|
||||
class Sender : public RpcOperator {
|
||||
public:
|
||||
explicit Sender(const ReceiverPtr &receiver) : server_url_(""), client_(nullptr), receiver_(receiver) {}
|
||||
~Sender();
|
||||
|
||||
// Send buffer to peer.
|
||||
bool Send(const std::vector<ShapeVector> &shapes, const std::vector<TypeId> data_types,
|
||||
const AddressPtrList &data_list) const;
|
||||
|
||||
// Lookup peer receiver's route and build network connection.
|
||||
bool ConnectServer();
|
||||
|
||||
private:
|
||||
// Build the MessageBase include dynamic shape protobuf, which will be sent to peer receiver.
|
||||
// The message format is as below:
|
||||
// |--------22 bytes-------|-------sizeof(size_t)-------|-dynamic shape PB data size-| real data size |
|
||||
// |RPC_DYNAMIC_SHAPE_DATA | dynamic shape PB data size |---dynamic shape PB data----|---real data----|
|
||||
// The message.from (from url) must be set.
|
||||
std::unique_ptr<MessageBase> BuildRpcMessage(const std::vector<ShapeVector> &shapes,
|
||||
const std::vector<TypeId> data_types, const AddressPtrList &data_list,
|
||||
const std::string &from_url, const std::string &to_url) const;
|
||||
|
||||
// The url of the peer receiver's tcp server.
|
||||
std::string server_url_;
|
||||
|
||||
std::unique_ptr<TCPClient> client_;
|
||||
|
||||
// The sender and the receiver are used in pairs. The information sent by the sender contains the url of the
|
||||
// corresponding receiver, so a reference to the receiver is maintained in the sender.
|
||||
ReceiverPtr receiver_;
|
||||
};
|
||||
|
||||
// Receiver is used to receive data from other process.
|
||||
class Receiver : public RpcOperator {
|
||||
public:
|
||||
Receiver() : ip_(""), port_(0), server_(nullptr), received_buffer_(nullptr), received_msg_(false) {}
|
||||
~Receiver();
|
||||
|
||||
// Receive message from the peer sender, this interface is a synchronous interface and will wait for the message
|
||||
// until the timeout period is reached.
|
||||
std::unique_ptr<std::vector<char>> Receive();
|
||||
|
||||
// Start receiver server and register this server address to route table in scheduler by proxy.
|
||||
bool StartServer();
|
||||
|
||||
// Get the url of this receiver, format: ip:port.
|
||||
std::string get_url() const { return ip_ + ":" + std::to_string(port_); }
|
||||
|
||||
private:
|
||||
// The message callback of the tcp server.
|
||||
MessageBase *HandleMessage(MessageBase *const msg);
|
||||
|
||||
// Parse the dynamic shape protobuf message. The format is as below:
|
||||
// |--------22 bytes-------|-------sizeof(size_t)-------|-dynamic shape PB data size-| real data size |
|
||||
// |RPC_DYNAMIC_SHAPE_DATA | dynamic shape PB data size |---dynamic shape PB data----|---real data----|
|
||||
// The output parameter 'data' contains real data addr and size.
|
||||
bool ParseDynamicShapeData(const char *msg_body, size_t msg_len, std::pair<const void *, size_t> *data) const;
|
||||
|
||||
// The network address of this receiver. It's generated automatically by rpc module.
|
||||
std::string ip_;
|
||||
uint32_t port_;
|
||||
|
||||
std::unique_ptr<TCPServer> server_;
|
||||
|
||||
// The buffer used save received content of message.
|
||||
std::unique_ptr<std::vector<char>> received_buffer_;
|
||||
|
||||
// The flag indicates whether receive message successfully.
|
||||
std::atomic_bool received_msg_;
|
||||
|
||||
// The interface 'Receive' is a synchronous, use condition variable to block thread and wait for the message.
|
||||
std::condition_variable received_msg_cv_;
|
||||
std::mutex received_msg_mtx_;
|
||||
};
|
||||
|
||||
using EmbeddingCachePrefetchActorPtr = std::shared_ptr<EmbeddingCachePrefetchActor>;
|
||||
} // namespace runtime
|
||||
} // namespace mindspore
|
||||
|
|
Loading…
Reference in New Issue