bugfix for embedding cache

lizhenyu 2022-06-09 15:14:27 +08:00
parent 7fe823582b
commit f78d645b69
16 changed files with 308 additions and 49 deletions

View File

@ -28,6 +28,7 @@
"mindspore/mindspore/core/load_mindir/anf_model_parser.cc" "stlIfStrFind"
"mindspore/mindspore/ccsrc/frontend/optimizer/ad/kpynative.cc" "containerOutOfBounds"
"mindspore/mindspore/ccsrc/pipeline/jit/action.cc" "unreadVariable"
"mindspore/mindspore/ccsrc/plugin/device/gpu/hal/hardware/gpu_device_context.c" "unreadVariable"
# MindData
"mindspore/mindspore/ccsrc/minddata/dataset/engine/dataset_iterator.cc" "useStlAlgorithm"

View File

@ -513,6 +513,10 @@ bool Connection::ParseMessage() {
total_recv_len -= recvLen;
return false;
}
if (!SetUrlForRecvMessage()) {
MS_LOG(ERROR) << "Set url info for recv message failed.";
return false;
}
recv_state = State::kMsgHeader;
break;
default:
@ -521,6 +525,24 @@ bool Connection::ParseMessage() {
return true;
}
bool Connection::SetUrlForRecvMessage() {
auto recv_from_separator_pos = recv_from.find('@');
auto recv_to_separator_pos = recv_to.find('@');
if (recv_from_separator_pos == std::string::npos || recv_to_separator_pos == std::string::npos) {
MS_LOG(ERROR) << "Invalid message format, cannot find separator '@'";
return false;
}
std::string from_name = recv_from.substr(0, recv_from_separator_pos);
std::string from_url = recv_from.substr(recv_from_separator_pos + 1);
std::string to_name = recv_to.substr(0, recv_to_separator_pos);
std::string to_url = recv_to.substr(recv_to_separator_pos + 1);
recv_message->from = AID(from_name, from_url);
recv_message->to = AID(to_name, to_url);
return true;
}
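For illustration, a minimal standalone sketch of the name@url split that SetUrlForRecvMessage performs; the address value is hypothetical, not taken from this commit:

#include <iostream>
#include <string>

int main() {
  std::string recv_from = "worker_0@127.0.0.1:10810";  // hypothetical "name@url" address
  auto pos = recv_from.find('@');
  if (pos == std::string::npos) {
    std::cerr << "Invalid address format, cannot find separator '@'" << std::endl;
    return 1;
  }
  std::string name = recv_from.substr(0, pos);  // "worker_0"
  std::string url = recv_from.substr(pos + 1);  // "127.0.0.1:10810"
  std::cout << name << " -> " << url << std::endl;
  return 0;
}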
void Connection::ReorderHeader(MessageHeader *header) const {
header->name_len = ntohl(header->name_len);
header->to_len = ntohl(header->to_len);

View File

@ -204,6 +204,9 @@ struct Connection {
// Parse message from socket recv buffer.
bool ParseMessage();
// After ParseMessage, set the from url and to url of the recv message.
bool SetUrlForRecvMessage();
// Make a http message based on given input message.
std::string GenerateHttpMessage(MessageBase *msg);

View File

@ -65,6 +65,26 @@ void PsEmbeddingCacheInserter::GetEmbeddingLookupNodes() {
});
}
void PsEmbeddingCacheInserter::GetCacheEnableParameters() {
MS_EXCEPTION_IF_NULL(root_graph_);
const std::vector<AnfNodePtr> &parameters = root_graph_->parameters();
auto params_size = parameters.size();
for (size_t i = 0; i < params_size; ++i) {
MS_EXCEPTION_IF_NULL(parameters[i]);
if (!parameters[i]->isa<Parameter>()) {
MS_LOG(EXCEPTION) << "The node with name: " << parameters[i]->fullname_with_scope() << "is not a Parameter.";
}
ParameterPtr param = parameters[i]->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(param);
auto param_info = param->param_info();
if (param_info && param_info->key() != -1 && param_info->cache_enable()) {
keys_to_params_[param_info->key()] = param;
MS_LOG(INFO) << "Parameter[" << param->fullname_with_scope() << "], key[" << param_info->key() << "]";
}
}
}
void PsEmbeddingCacheInserter::SetNodeAttr(const CNodePtr &node, const std::string &node_role) const {
MS_EXCEPTION_IF_NULL(node);
@ -107,7 +127,7 @@ void PsEmbeddingCacheInserter::SetSendNodeAttr(const CNodePtr &send_node, int32_
dst_ranks.push_back(i);
dst_roles.push_back(dst_role);
// Unique edge name: src role + src rank id -> dst role + dst rank id + embedding cache operation + parameter key.
inter_process_edges.push_back(distributed::kEnvRoleOfServer + std::to_string(rank_id_) + "->" + dst_role +
inter_process_edges.push_back(distributed::kEnvRoleOfPServer + std::to_string(rank_id_) + "->" + dst_role +
std::to_string(i) + "_" + embedding_cache_op + "_" + distributed::kParameterKey +
std::to_string(param_key));
}
@ -118,6 +138,7 @@ void PsEmbeddingCacheInserter::SetSendNodeAttr(const CNodePtr &send_node, int32_
common::AnfAlgo::SetNodeAttr(kAttrSendDstNodeName, MakeValue(std::string(kEmbeddingLocalCacheNode)), send_node);
common::AnfAlgo::SetNodeAttr(kAttrInterProcessEdgeNames, MakeValue(inter_process_edges), send_node);
common::AnfAlgo::SetNodeAttr(kAttrIsMuxRpcKernel, MakeValue(true), send_node);
}
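As a sketch of the edge-name convention described in the comment above, with illustrative role strings, ranks, and key tag (the real code uses constants such as distributed::kEnvRoleOfPServer and distributed::kParameterKey, whose exact values are assumptions here):

#include <cstdint>
#include <iostream>
#include <string>

int main() {
  const std::string src_role = "MS_PSERVER";  // assumed value of kEnvRoleOfPServer
  const std::string dst_role = "MS_WORKER";   // assumed value of kEnvRoleOfWorker
  const uint32_t src_rank = 0;
  const uint32_t dst_rank = 1;
  const std::string cache_op = "LookupEmbeddingCache";
  const std::string key_tag = "key_";         // stands in for kParameterKey
  const int32_t param_key = 2;

  std::string edge = src_role + std::to_string(src_rank) + "->" + dst_role + std::to_string(dst_rank) +
                     "_" + cache_op + "_" + key_tag + std::to_string(param_key);
  std::cout << edge << std::endl;  // MS_PSERVER0->MS_WORKER1_LookupEmbeddingCache_key_2
  return 0;
}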
void PsEmbeddingCacheInserter::SetRecvNodeAttr(const CNodePtr &recv_node, const std::string &src_role) const {
@ -139,7 +160,7 @@ void PsEmbeddingCacheInserter::SetRecvNodeAttr(const CNodePtr &recv_node, const
src_roles.push_back(src_role);
// Unique edge name: src role + src rank id -> dst role + dst rank id + embedding cache operation + parameter
// key.
inter_process_edges.push_back(src_role + std::to_string(i) + "->" + distributed::kEnvRoleOfServer +
inter_process_edges.push_back(src_role + std::to_string(i) + "->" + distributed::kEnvRoleOfPServer +
std::to_string(rank_id_) + "_" + distributed::kEmbeddingCacheOps[k] + "_" +
distributed::kParameterKey + std::to_string(param_key));
}
@ -150,7 +171,9 @@ void PsEmbeddingCacheInserter::SetRecvNodeAttr(const CNodePtr &recv_node, const
common::AnfAlgo::SetNodeAttr(kAttrRecvSrcRoles, MakeValue(src_roles), recv_node);
common::AnfAlgo::SetNodeAttr(kAttrRecvSrcNodeName, MakeValue(std::string(kEmbeddingLocalCacheNode)), recv_node);
common::AnfAlgo::SetNodeAttr(kAttrRecvDstNodeName, MakeValue(std::string(kEmbeddingRemoteCacheNode)), recv_node);
common::AnfAlgo::SetNodeAttr(kAttrInterProcessEdgeNames, MakeValue(inter_process_edges), recv_node);
common::AnfAlgo::SetNodeAttr(kAttrIsMuxRpcKernel, MakeValue(true), recv_node);
}
CNodePtr PsEmbeddingCacheInserter::CreateReturnNode(const FuncGraphPtr graph, const AnfNodePtr &output_node) const {
@ -421,6 +444,9 @@ bool PsEmbeddingCacheInserter::Run() {
// Get EmbeddingLookup nodes which are executed on server from origin function graph.
GetEmbeddingLookupNodes();
// Get the parameters with embedding cache enabled from the origin function graph.
GetCacheEnableParameters();
// Construct the embedding cache graph of server.
RETURN_IF_FALSE_WITH_LOG(ConstructEmbeddingCacheGraph(), "Construct embedding cache graph failed.");

View File

@ -21,7 +21,6 @@
#include <map>
#include <vector>
#include "utils/hash_map.h"
#include "ir/anf.h"
#include "distributed/constants.h"
@ -104,7 +103,7 @@ class PsEmbeddingCacheInserter {
// Record the parameters with embedding cache enabled from the origin function graph.
// Key: parameter key, Value: ParameterPtr
mindspore::HashMap<int32_t, ParameterPtr> keys_to_params_;
std::map<int32_t, ParameterPtr> keys_to_params_;
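A plausible motivation for switching keys_to_params_ from mindspore::HashMap to std::map (an assumption; the commit does not state one) is deterministic, key-ordered traversal, so every process builds per-key nodes and edges in the same order. A minimal sketch:

#include <cstdint>
#include <iostream>
#include <map>
#include <string>

int main() {
  // std::map iterates in ascending key order regardless of insertion order;
  // a hash map gives no such ordering guarantee.
  std::map<int32_t, std::string> keys_to_params = {{2, "embed_b"}, {0, "embed_a"}, {1, "embed_c"}};
  for (const auto &kv : keys_to_params) {
    std::cout << kv.first << " -> " << kv.second << std::endl;  // keys print as 0, 1, 2
  }
  return 0;
}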
// Record EmbeddingLookup nodes which are executed on server from origin function graph.
// Key: shape of EmbeddingLookup node, Value: EmbeddingLookup AnfNodePtr.

View File

@ -580,6 +580,7 @@ constexpr auto kAttrRecvSrcRanks = "recv_src_ranks";
constexpr auto kAttrRecvSrcRoles = "recv_src_roles";
constexpr auto kAttrInterProcessEdgeNames = "inter_process_edge_names";
constexpr auto kAttrInterProcessEdgeLabel = "inter_process_edge_label";
constexpr auto kAttrIsMuxRpcKernel = "is_mux_rpc_kernel";
constexpr auto kAttrForwardOpOutputId = "forward_op_output_id";
constexpr auto kAttrGroupRankIds = "group_rank_ids";
constexpr auto kAttrReuseCommunication = "reuse_communication_node";

View File

@ -1423,6 +1423,12 @@ bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batc
MS_EXCEPTION_IF_NULL(backend);
// The data set graph compiling and running of mindRT.
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_MINDRT)) {
#ifdef WITH_BACKEND
if (ps::PSContext::instance()->is_worker() && ps::PSContext::instance()->cache_enable()) {
ps::PsDataPrefetch::GetInstance().CreateDataChannel(queue_name, LongToSize(size));
}
#endif
const auto &mindrt_backend = std::dynamic_pointer_cast<compile::MindRTBackend>(backend);
MS_EXCEPTION_IF_NULL(mindrt_backend);
auto &actor_info = mindrt_backend->CompileGraphs(func_graph);

View File

@ -361,6 +361,61 @@ void SetKernelInfoBeforeCreateKernel(const std::vector<CNodePtr> &nodes) {
}
}
}
// Check whether mutex exists for a stream.
std::pair<bool, std::mutex *> CheckStreamMutexExist(
const void *stream, const mindspore::HashMap<const void *, std::shared_ptr<std::mutex>> &mtxs_for_streams,
std::shared_mutex *shd_mtx) {
MS_EXCEPTION_IF_NULL(stream);
MS_EXCEPTION_IF_NULL(shd_mtx);
std::shared_lock<std::shared_mutex> shd_lock(*shd_mtx);
auto iter = mtxs_for_streams.find(stream);
if (iter != mtxs_for_streams.end()) {
MS_EXCEPTION_IF_NULL(iter->second);
return std::make_pair(true, iter->second.get());
}
return std::make_pair(false, nullptr);
}
// Create a mutex for stream.
std::mutex *CreateStreamMutex(const void *stream, std::shared_mutex *shd_mtx,
mindspore::HashMap<const void *, std::shared_ptr<std::mutex>> *mtxs_for_streams) {
MS_EXCEPTION_IF_NULL(stream);
MS_EXCEPTION_IF_NULL(shd_mtx);
MS_EXCEPTION_IF_NULL(mtxs_for_streams);
std::unique_lock<std::shared_mutex> unq_lock(*shd_mtx);
auto ret_pair = mtxs_for_streams->emplace(stream, std::make_shared<std::mutex>());
MS_EXCEPTION_IF_NULL(ret_pair.first->second);
return ret_pair.first->second.get();
}
// Kernel launch is not thread-safe, and delivering kernel launches to the same stream requires lock protection,
// so a separate lock is created for each stream.
// For GPU specifically, the cuBLAS handle is not thread-safe: multiple threads should not access the same cuBLAS
// handle at the same time, so the launch mutex is needed when multiple threads launch cuBLAS kernels.
std::lock_guard<std::mutex> LockLaunchKernel(const void *stream) {
MS_EXCEPTION_IF_NULL(stream);
// Read-write lock for accessing mtxs_for_streams map.
// When the lock of each stream is created, mtxs_for_streams can be accessed concurrently to improve performance.
static std::shared_mutex shd_mtx;
static mindspore::HashMap<const void *, std::shared_ptr<std::mutex>> mtxs_for_streams;
std::mutex *stream_mtx;
// Check whether mutex exists for a stream.
std::pair<bool, std::mutex *> ret_pair = CheckStreamMutexExist(stream, mtxs_for_streams, &shd_mtx);
if (ret_pair.first) {
stream_mtx = ret_pair.second;
} else {
// Create a mutex for stream.
stream_mtx = CreateStreamMutex(stream, &shd_mtx, &mtxs_for_streams);
}
MS_EXCEPTION_IF_NULL(stream_mtx);
// Lock kernel launch for the stream.
return std::lock_guard<std::mutex>(*stream_mtx);
}
} // namespace
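A self-contained sketch of the per-stream lock registry implemented above, assuming C++17 (guaranteed copy elision allows returning std::lock_guard by value, as the code does); std::map stands in for mindspore::HashMap so the sketch compiles on its own:

#include <map>
#include <memory>
#include <mutex>
#include <shared_mutex>

class StreamLockRegistry {
 public:
  // Return a guard that holds this stream's launch mutex until it leaves scope.
  std::lock_guard<std::mutex> Lock(const void *stream) {
    std::mutex *stream_mtx = nullptr;
    {
      // Fast path: concurrent readers can look up already-created locks in parallel.
      std::shared_lock<std::shared_mutex> read_lock(map_mutex_);
      auto iter = locks_.find(stream);
      if (iter != locks_.end()) {
        stream_mtx = iter->second.get();
      }
    }
    if (stream_mtx == nullptr) {
      // Slow path: take the writer lock and insert. If another thread inserted
      // first, emplace is a no-op and both threads share the same mutex.
      std::unique_lock<std::shared_mutex> write_lock(map_mutex_);
      auto ret = locks_.emplace(stream, std::make_unique<std::mutex>());
      stream_mtx = ret.first->second.get();
    }
    // Entries are never erased, so the pointer stays valid after the map lock is released.
    return std::lock_guard<std::mutex>(*stream_mtx);
  }

 private:
  std::shared_mutex map_mutex_;  // guards the map itself, not the launches
  std::map<const void *, std::unique_ptr<std::mutex>> locks_;  // one launch mutex per stream
};

A caller would write auto guard = registry.Lock(stream); immediately before the launch, mirroring LockLaunchKernel above.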
void GPUDeviceContext::OptimizeGraph(const FuncGraphPtr &graph) const {
@ -462,20 +517,23 @@ bool GPUDeviceContext::LaunchKernel(const CNodePtr &kernel, const std::vector<Ad
}
bool ret = true;
auto stream = GetLaunchKernelStream(kernel);
MS_EXCEPTION_IF_NULL(stream);
#ifndef ENABLE_SECURITY
const auto &profiler_inst = profiler::gpu::GPUProfiler::GetInstance();
MS_EXCEPTION_IF_NULL(profiler_inst);
if (!profiler_inst->GetEnableFlag()) {
#endif
std::lock_guard<std::mutex> locker(launch_mutex_);
auto lock = LockLaunchKernel(stream);
MS_LOG(DEBUG) << "Launch kernel: " << kernel->fullname_with_scope();
ret = DoLaunchKernel(kernel, inputs, workspace, outputs);
ret = DoLaunchKernel(kernel, inputs, workspace, outputs, stream);
#ifndef ENABLE_SECURITY
} else {
std::lock_guard<std::mutex> locker(launch_mutex_);
auto lock = LockLaunchKernel(stream);
MS_LOG(DEBUG) << "Launch kernel: " << kernel->fullname_with_scope();
ret = LaunchKernelWithProfiling(kernel, inputs, workspace, outputs);
ret = LaunchKernelWithProfiling(kernel, inputs, workspace, outputs, stream);
}
#endif
if (!ret) {
@ -496,8 +554,9 @@ bool GPUDeviceContext::LaunchKernel(const CNodePtr &kernel, const std::vector<Ad
#ifndef ENABLE_SECURITY
bool GPUDeviceContext::LaunchKernelWithProfiling(const CNodePtr &kernel, const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) const {
const std::vector<AddressPtr> &outputs, void *stream) const {
MS_EXCEPTION_IF_NULL(kernel);
MS_EXCEPTION_IF_NULL(stream);
auto kernel_graph = std::dynamic_pointer_cast<KernelGraph>(kernel->func_graph());
MS_EXCEPTION_IF_NULL(kernel_graph);
@ -512,7 +571,7 @@ bool GPUDeviceContext::LaunchKernelWithProfiling(const CNodePtr &kernel, const s
}
profiler_inst->OpDataProducerBegin(kernel->fullname_with_scope(), streams_.front());
bool ret = DoLaunchKernel(kernel, inputs, workspace, outputs);
bool ret = DoLaunchKernel(kernel, inputs, workspace, outputs, stream);
profiler_inst->OpDataProducerEnd();
profiler_inst->RecordFrameWorkInfo(kernel);
@ -527,8 +586,16 @@ bool GPUDeviceContext::LaunchKernelWithProfiling(const CNodePtr &kernel, const s
}
#endif
bool GPUDeviceContext::DoLaunchKernel(const CNodePtr &kernel, const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) const {
const std::vector<AddressPtr> &workspace, const std::vector<AddressPtr> &outputs,
void *stream) const {
MS_EXCEPTION_IF_NULL(kernel);
MS_EXCEPTION_IF_NULL(stream);
auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
MS_EXCEPTION_IF_NULL(kernel_mod);
return kernel_mod->Launch(inputs, workspace, outputs, stream);
}
void *GPUDeviceContext::GetLaunchKernelStream(const CNodePtr &kernel) const {
void *stream = nullptr;
if (common::AnfAlgo::HasNodeAttr(kAttrStream, kernel)) {
auto stream_id = common::AnfAlgo::GetNodeAttr<size_t>(kernel, kAttrStream);
@ -542,9 +609,7 @@ bool GPUDeviceContext::DoLaunchKernel(const CNodePtr &kernel, const std::vector<
}
MS_EXCEPTION_IF_NULL(stream);
auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
MS_EXCEPTION_IF_NULL(kernel_mod);
return kernel_mod->Launch(inputs, workspace, outputs, stream);
return stream;
}
bool GPUDeviceContext::SyncStream(size_t stream_id) const {

View File

@ -92,12 +92,17 @@ class GPUDeviceContext : public DeviceContext {
#ifndef ENABLE_SECURITY
// Launch a kernel and record the elapsed time end to end.
bool LaunchKernelWithProfiling(const CNodePtr &kernel, const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) const;
const std::vector<AddressPtr> &workspace, const std::vector<AddressPtr> &outputs,
void *stream) const;
#endif
// Launch a kernel by 'KernelMod' of the kernel.
bool DoLaunchKernel(const CNodePtr &kernel, const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace, const std::vector<AddressPtr> &outputs) const;
const std::vector<AddressPtr> &workspace, const std::vector<AddressPtr> &outputs,
void *stream) const;
// Get the stream used to launch the kernel: if a stream is saved in the kernel's attrs, use that stream;
// otherwise use the default stream.
void *GetLaunchKernelStream(const CNodePtr &kernel) const;
// Really create a cuda stream.
bool CreateStream(void **stream) const override;
@ -105,10 +110,6 @@ class GPUDeviceContext : public DeviceContext {
// Really destroy a cuda stream.
bool DestroyStream(void *stream) const override;
// The cublas handle is not thread safety specifically, it is not recommended that multiple threads access the same
// cublas handle at the same time, so need the launch mutex when multiple threads launch the cublas kernels.
mutable std::mutex launch_mutex_;
std::shared_ptr<MemoryManager> mem_manager_;
std::vector<void *> streams_;
bool initialized_;

View File

@ -22,6 +22,11 @@ namespace ps {
const size_t kTimeoutLoopCount = 40;
const int64_t kLongestTimeToWait = 30;
PsDataPrefetch &PsDataPrefetch::GetInstance() {
static PsDataPrefetch instance;
return instance;
}
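Defining GetInstance() in the .cc file rather than the header (see the matching header change below) plausibly guarantees a single instance even when the code is linked into several shared objects, since a header-defined static local can be instantiated once per binary module; this rationale is an assumption, not stated in the commit. A minimal sketch of the pattern, with a hypothetical class name:

// sketch_singleton.h: only the declaration lives in the header.
class PsDataPrefetchSketch {
 public:
  static PsDataPrefetchSketch &GetInstance();
};

// sketch_singleton.cc: the definition lives in exactly one translation unit,
// so every module linking against it sees the same instance.
PsDataPrefetchSketch &PsDataPrefetchSketch::GetInstance() {
  static PsDataPrefetchSketch instance;  // constructed once on first use (thread-safe since C++11)
  return instance;
}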
void PsDataPrefetch::CreateDataChannel(const std::string &channel_name, size_t step_num) {
if (cache_enable_ == false) {
return;

View File

@ -29,10 +29,7 @@ namespace mindspore {
namespace ps {
class EXPORT PsDataPrefetch {
public:
EXPORT static PsDataPrefetch &GetInstance() {
static PsDataPrefetch instance;
return instance;
}
EXPORT static PsDataPrefetch &GetInstance();
EXPORT bool cache_enable() const { return cache_enable_; }
EXPORT void set_cache_enable(bool cache_enable) { cache_enable_ = cache_enable; }

View File

@ -19,6 +19,7 @@
#include "runtime/graph_scheduler/device_tensor_store.h"
#include "utils/ms_context.h"
#include "include/common/utils/anfalgo.h"
#include "ps/ps_context.h"
namespace mindspore {
namespace runtime {
@ -391,5 +392,9 @@ std::set<size_t> FetchModifiableRefOutputIndex(const CNodePtr &cnode, const Kern
}
return ref_output_indexes;
}
bool is_embedding_cache_server() {
return ps::PSContext::instance()->cache_enable() && ps::PSContext::instance()->is_server();
}
} // namespace runtime
} // namespace mindspore

View File

@ -285,6 +285,9 @@ std::string FetchActorName(KernelTransformType kernel_type, const std::string &a
std::set<size_t> FetchModifiableRefInputIndex(const CNodePtr &node);
// Fetch the output indexes which may be modified that exist in the ref node.
std::set<size_t> FetchModifiableRefOutputIndex(const CNodePtr &node, const KernelGraphPtr &graph);
// Check whether this process is a parameter server with embedding cache enabled.
bool is_embedding_cache_server();
} // namespace runtime
} // namespace mindspore

View File

@ -22,6 +22,7 @@
namespace mindspore {
namespace runtime {
using distributed::cluster::ClusterContext;
using mindspore::session::KernelGraph;
// One and two dimensional shape placeholder.
@ -58,7 +59,7 @@ bool InferOpShape(const CNodePtr &kernel) {
auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
MS_EXCEPTION_IF_NULL(kernel_mod);
if (!kernel_mod->Resize(args->op, args->inputs, args->outputs, args->depend_tensor_map)) {
if (kernel::KRET_OK != kernel_mod->Resize(args->op, args->inputs, args->outputs, args->depend_tensor_map)) {
MS_LOG(ERROR) << "Kernel " << kernel->fullname_with_scope() << " resize failed.";
return false;
}
@ -88,14 +89,17 @@ SendRecvPair CreateSenderReceiverPair(uint32_t worker_rank, uint32_t server_rank
int32_t param_key) {
// Create sender and receiver pair.
ReceiverPtr receiver = std::make_shared<Receiver>();
SenderPtr sender = std::make_shared<Sender>(receiver);
MS_EXCEPTION_IF_NULL(receiver);
SenderPtr sender = std::make_shared<Sender>();
MS_EXCEPTION_IF_NULL(sender);
sender->set_receiver(receiver);
// Set inter process edge
receiver->set_inter_process_edge_name(GenerateInterProcessEdge(distributed::kEnvRoleOfServer, server_rank,
receiver->set_inter_process_edge_name(GenerateInterProcessEdge(distributed::kEnvRoleOfPServer, server_rank,
distributed::kEnvRoleOfWorker, worker_rank,
cache_operation, param_key));
sender->set_inter_process_edge_name(GenerateInterProcessEdge(distributed::kEnvRoleOfWorker, worker_rank,
distributed::kEnvRoleOfServer, server_rank,
distributed::kEnvRoleOfPServer, server_rank,
distributed::kLookupEmbeddingCache, param_key));
// Set route table proxy.
@ -172,9 +176,24 @@ void EmbeddingCachePrefetchActor::Initialize() {
MS_LOG(EXCEPTION) << "Create stream failed.";
}
// Get embedding cache table info.
hash_tables_ = embedding_cache_table_manager.hash_tables_;
local_host_cache_size_ = embedding_cache_table_manager.host_cache_size_;
vocab_size_ = embedding_cache_table_manager.vocab_size_;
embedding_device_cache_ = embedding_cache_table_manager.embedding_device_cache_;
MS_EXCEPTION_IF_NULL(embedding_device_cache_);
embedding_host_cache_ = embedding_cache_table_manager.embedding_host_cache_;
MS_EXCEPTION_IF_NULL(embedding_host_cache_);
local_embedding_slice_bounds_ = embedding_cache_table_manager.local_embedding_slice_bounds_;
local_device_cache_bounds_ = embedding_cache_table_manager.local_device_cache_bounds_;
// Get the id range of each server's embedding table slice.
GetRemoteEmbeddingSliceBound();
BuildEmbeddingCacheLookupKernel();
BuildEmbeddingCacheUpdateKernel();
// Build and link rpc operators.
BuildRpcOperators();
LinkRpcOperators();
}
@ -326,8 +345,40 @@ bool EmbeddingCachePrefetchActor::UpdateDeviceCache(void *indices, void *update_
return true;
}
void EmbeddingCachePrefetchActor::IncreaseGraphStep(const std::string &channel_name) {
if (!running_) {
MS_LOG(EXCEPTION) << "PS embedding cache data processing thread isn't running.";
}
if (graph_step_ >= UINT64_MAX) {
MS_LOG(EXCEPTION) << "The graph step(" << graph_step_ << ") will exceed the maximum value of uint64_t.";
}
if (graph_step_ == 0) {
MS_LOG(INFO) << "Graph running waiting embedding table init begin:" << finish_init_parameters_on_remote_;
std::unique_lock<std::mutex> locker(data_mutex_);
data_parser_.wait(locker, [this] { return ((finish_init_parameters_on_remote_ == true) || (running_ == false)); });
if (!running_) {
MS_LOG(EXCEPTION) << "PS embedding cache data processing thread isn't running.";
}
MS_LOG(INFO) << "Graph running waiting embedding table init end.";
}
graph_step_++;
set_channel_name(channel_name);
if (!PsDataPrefetch::GetInstance().TryWakeChannel(channel_name)) {
MS_LOG(EXCEPTION) << "TryWakeChannel failed, channel name: " << channel_name;
}
data_parser_.notify_one();
}
void EmbeddingCachePrefetchActor::Run() {
// Note:Need to wait data channel ready.
running_ = true;
// Wait for parameter initialization on the remote side.
// This prevents subsequent cache prefetching from failing due to the long initialization time of large parameters
// on the remote side.
WaitInitParametersOnRemote();
// Wait data channel ready.
WaitDataChannelInit();
MS_LOG(INFO) << "Begin prefetching cache.";
while (running_) {
@ -728,7 +779,6 @@ bool EmbeddingCachePrefetchActor::PushCacheFromDeviceToLocalHost(const HashTable
}
MS_ERROR_IF_NULL(embedding_device_cache_);
MS_ERROR_IF_NULL(embedding_device_cache_->cache_);
MS_ERROR_IF_NULL(embedding_host_cache_);
auto device_cache_device_to_host_index = embedding_device_cache_->device_to_host_index.get();
@ -798,7 +848,6 @@ bool EmbeddingCachePrefetchActor::PullCacheFromLocalHostToDevice(const HashTable
}
MS_ERROR_IF_NULL(embedding_device_cache_);
MS_ERROR_IF_NULL(embedding_device_cache_->cache_);
MS_ERROR_IF_NULL(embedding_host_cache_);
auto host_cache_host_to_device_index = embedding_host_cache_->host_to_device_index.get();
@ -1126,7 +1175,6 @@ bool EmbeddingCachePrefetchActor::SendToRemote(const std::string &cache_operatio
size_t server_rank_id, size_t embedding_dim, const void *keys,
size_t keys_len, const void *values, size_t values_len) {
MS_ERROR_IF_NULL(keys);
MS_ERROR_IF_NULL(values);
// Find sender corresponding to cache operation and parameter key.
auto iter = rpc_operators_.find(cache_operation);
if (iter == rpc_operators_.end()) {
@ -1138,8 +1186,26 @@ bool EmbeddingCachePrefetchActor::SendToRemote(const std::string &cache_operatio
MS_ERROR_IF_NULL(sender);
int64_t ids_num = SizeToLong(keys_len / sizeof(int));
std::vector<ShapeVector> shapes = {{ids_num}, {ids_num, SizeToLong(embedding_dim)}, {1}};
std::vector<TypeId> data_types = {kNumberTypeInt32, kNumberTypeFloat32, kNumberTypeInt64};
ShapeVector ids_shape = {ids_num};
ShapeVector values_shape;
float fake_value = 0.0;
if (values == nullptr && values_len == 0) {
values_shape = {1, 1};
values = &fake_value;
values_len = sizeof(fake_value);
} else {
MS_EXCEPTION_IF_ZERO("embedding_dim", embedding_dim);
int64_t embed_vec_num = SizeToLong(values_len / sizeof(float) / embedding_dim);
if (embed_vec_num != ids_num) {
MS_LOG(EXCEPTION) << "The embedding vector number[" << embed_vec_num << "] shouled be equal to ids number["
<< ids_num << "] which will be send to remote.";
}
values_shape = {embed_vec_num, SizeToLong(embedding_dim)};
}
std::vector<ShapeVector> shapes = {ids_shape, values_shape, {static_cast<int64_t>(1)}};
std::vector<TypeId> data_types = {kNumberTypeInt32, kNumberTypeFloat32, kNumberTypeInt32};
int32_t service_id = GetCacheOpsServiceId(cache_operation, param_key);
AddressPtrList data_list = {std::make_shared<Address>(const_cast<void *>(keys), keys_len),
@ -1220,6 +1286,39 @@ bool EmbeddingCachePrefetchActor::RetrieveEmbeddings(
return true;
}
std::string EmbeddingCachePrefetchActor::channel_name() {
std::lock_guard<std::mutex> locker(channel_mutex_);
return channel_name_;
}
void EmbeddingCachePrefetchActor::set_channel_name(const std::string channel_name) {
std::lock_guard<std::mutex> locker(channel_mutex_);
if (channel_name_ == channel_name) {
return;
}
channel_name_ = channel_name;
}
void EmbeddingCachePrefetchActor::WaitDataChannelInit() {
MS_LOG(INFO) << "Begin wait embedding cache data channel init.";
auto channel = channel_name();
if (channel.empty()) {
std::unique_lock<std::mutex> locker(data_mutex_);
data_parser_.wait(locker, [this] { return !channel_name_.empty() || running_ == false; });
if (!running_) {
return;
}
}
MS_LOG(INFO) << "End wait embedding cache data channel init.";
}
void EmbeddingCachePrefetchActor::WaitInitParametersOnRemote() {
std::unique_lock<std::mutex> locker(data_mutex_);
// Note: wait to finish embedding lookup from remote.
finish_init_parameters_on_remote_ = true;
data_parser_.notify_one();
}
void EmbeddingCachePrefetchActor::BuildRpcOperators() {
// The cache operation support LookupEmbeddingCache and UpdateEmbeddingCache currently.
for (const auto &cache_op : distributed::kEmbeddingCacheOps) {

View File

@ -30,19 +30,12 @@
#include "distributed/rpc/tcp/tcp_client.h"
#include "distributed/rpc/tcp/tcp_server.h"
#include "utils/hash_map.h"
#include "distributed/embedding_cache/embedding_cache_utils.h"
// Note: After the code in ps/ps_cache is moved into runtime/addons/embedding_cache/,
// the following include files and using declarations of ps will be removed.
#include "ps/ps_cache/ps_data/ps_data_prefetch.h"
#include "ps/ps_cache/embedding_hash_map.h"
#include "ps/ps_cache/ps_cache_manager.h"
#include "ps/ps_context.h"
using mindspore::ps::EmbeddingDeviceCache;
using mindspore::ps::EmbeddingHostCache;
using mindspore::ps::HashTableInfo;
using mindspore::ps::INVALID_INDEX_VALUE;
using mindspore::ps::INVALID_STEP_VALUE;
using mindspore::ps::PsCacheStatisticsInfo;
using mindspore::ps::PSContext;
using mindspore::ps::PsDataPrefetch;
@ -59,8 +52,17 @@ using ReceiverPtr = std::shared_ptr<Receiver>;
using SendRecvPair = std::pair<SenderPtr, ReceiverPtr>;
using SendRecvPairList = std::vector<SendRecvPair>;
using distributed::EmbeddingCacheStatisticsInfo;
using distributed::EmbeddingDeviceCache;
using distributed::EmbeddingHostCache;
using distributed::HashTableInfo;
using distributed::INVALID_INDEX_VALUE;
using distributed::INVALID_STEP_VALUE;
using distributed::cluster::ActorRouteTableProxy;
using distributed::cluster::ActorRouteTableProxyPtr;
using distributed::rpc::TCPClient;
using distributed::rpc::TCPServer;
// The EmbeddingCachePrefetchActor is used in large embedding table scenarios. The cache level is: Device
// Cache->Local Host Cache->Remote Cache. This Actor is used to perform Local and Device Cache hit analysis and cache
@ -82,6 +84,9 @@ class EmbeddingCachePrefetchActor : public ActorBase {
// Perform local cache hit analysis, prefetch the feature vector corresponding to the next batch into the cache.
void Run();
// Increase the global step of compute graph.
void IncreaseGraphStep(const std::string &channel_name);
// Finalize embedding cache prefetch actor and push latest embedding from local cache to remote cache.
void Finalize();
@ -206,6 +211,19 @@ class EmbeddingCachePrefetchActor : public ActorBase {
bool UpdateDeviceCache(void *indices, void *update_value, size_t indices_num, size_t cache_size,
size_t embedding_size, void *embedding_cache);
// Get dataset channel name.
std::string channel_name();
// Set dataset channel name.
void set_channel_name(const std::string channel_name);
// Wait data channel ready.
void WaitDataChannelInit();
// Wait for parameter initialization on the remote side.
// This prevents subsequent cache prefetching from failing due to the long initialization time of large parameters
// on the remote side.
void WaitInitParametersOnRemote();
// Record sender and receiver pairs for different cache operation, server and parameter key.
// key: cache operation(such as LookupEmbeddingCache and UpdateEmbeddingCache)
// value: sender and receiver pairs for this kind of cache operation.
@ -239,7 +257,7 @@ class EmbeddingCachePrefetchActor : public ActorBase {
std::shared_ptr<EmbeddingHostCache> embedding_host_cache_;
// Statistics on the cache hit rate of the host and device and the information used to update cache.
PsCacheStatisticsInfo statistics_info_;
EmbeddingCacheStatisticsInfo statistics_info_;
// Model parallelism is used between multiple workers, and local_embedding_slice_bounds_ records the feature range
// corresponding to the embedding table slice of the process.
@ -255,7 +273,7 @@ class EmbeddingCachePrefetchActor : public ActorBase {
std::vector<std::pair<size_t, size_t>> remote_embedding_slice_bounds_;
// Total server number of cluster.
size_t server_num_;
size_t server_num_{0};
// The flag which indicates whether this actor is running to prefetch cache.
std::atomic_bool running_{false};
@ -269,6 +287,11 @@ class EmbeddingCachePrefetchActor : public ActorBase {
// Dataset channel name, used in dataset switching scenarios.
std::string channel_name_;
// The mutex to access channel_name_.
std::mutex channel_mutex_;
// The flag indicates whether initializing parameters on the remote side has finished.
std::atomic_bool finish_init_parameters_on_remote_{false};
// Data parser condition variable for prefetching cache, used to start and synchronize intermediate state for cache
// prefetching.
@ -311,13 +334,16 @@ class RpcOperator {
// Sender is used to send data to other process.
class Sender : public RpcOperator {
public:
explicit Sender(const ReceiverPtr &receiver) : server_url_(""), client_(nullptr), receiver_(receiver) {}
Sender() : server_url_(""), client_(nullptr) {}
~Sender();
// Send buffer to peer.
bool Send(const std::vector<ShapeVector> &shapes, const std::vector<TypeId> data_types,
const AddressPtrList &data_list) const;
// Set the receiver paired with the sender to get the 'from url' from the receiver.
void set_receiver(const ReceiverPtr &receiver) { receiver_ = receiver; }
// Lookup peer receiver's route and build network connection.
bool ConnectServer();

View File

@ -15,10 +15,10 @@
*/
#include "runtime/graph_scheduler/embedding_cache_scheduler.h"
#include "runtime/graph_scheduler/actor/embedding_cache/embedding_cache_prefetch_actor.h"
#include <string>
#include <memory>
#include "runtime/graph_scheduler/actor/embedding_cache/embedding_cache_prefetch_actor.h"
#include "utils/ms_context.h"
namespace mindspore {
namespace runtime {