bugfix for embedding cache

parent 7fe823582b
commit f78d645b69
@@ -28,6 +28,7 @@
 "mindspore/mindspore/core/load_mindir/anf_model_parser.cc" "stlIfStrFind"
 "mindspore/mindspore/ccsrc/frontend/optimizer/ad/kpynative.cc" "containerOutOfBounds"
 "mindspore/mindspore/ccsrc/pipeline/jit/action.cc" "unreadVariable"
+"mindspore/mindspore/ccsrc/plugin/device/gpu/hal/hardware/gpu_device_context.cc" "unreadVariable"

 # MindData
 "mindspore/mindspore/ccsrc/minddata/dataset/engine/dataset_iterator.cc" "useStlAlgorithm"
@@ -513,6 +513,10 @@ bool Connection::ParseMessage() {
         total_recv_len -= recvLen;
         return false;
       }
+      if (!SetUrlForRecvMessage()) {
+        MS_LOG(ERROR) << "Set url info for recv message failed.";
+        return false;
+      }
       recv_state = State::kMsgHeader;
       break;
     default:
@@ -521,6 +525,24 @@ bool Connection::ParseMessage() {
   return true;
 }

+bool Connection::SetUrlForRecvMessage() {
+  auto recv_from_separator_pos = recv_from.find('@');
+  auto recv_to_separator_pos = recv_to.find('@');
+  if (recv_from_separator_pos == std::string::npos || recv_to_separator_pos == std::string::npos) {
+    MS_LOG(ERROR) << "Invalid message format, cannot find separator '@'";
+    return false;
+  }
+
+  std::string from_name = recv_from.substr(0, recv_from_separator_pos);
+  std::string from_url = recv_from.substr(recv_from_separator_pos + 1);
+  std::string to_name = recv_to.substr(0, recv_to_separator_pos);
+  std::string to_url = recv_to.substr(recv_to_separator_pos + 1);
+  recv_message->from = AID(from_name, from_url);
+  recv_message->to = AID(to_name, to_url);
+
+  return true;
+}
+
 void Connection::ReorderHeader(MessageHeader *header) const {
   header->name_len = ntohl(header->name_len);
   header->to_len = ntohl(header->to_len);
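The split above assumes both recv_from and recv_to carry a "name@url" pair. A minimal standalone sketch of that parsing step, using plain std::string rather than the MindSpore Connection/AID types:

```cpp
#include <iostream>
#include <string>
#include <utility>

// Split "name@url" into its two halves; returns empty strings when the
// separator is missing, mirroring the error path in SetUrlForRecvMessage.
std::pair<std::string, std::string> SplitNameAndUrl(const std::string &s) {
  auto pos = s.find('@');
  if (pos == std::string::npos) {
    return {"", ""};
  }
  return {s.substr(0, pos), s.substr(pos + 1)};
}

int main() {
  auto [name, url] = SplitNameAndUrl("actor_0@127.0.0.1:8080");
  std::cout << name << " -> " << url << "\n";  // actor_0 -> 127.0.0.1:8080
  return 0;
}
```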
@@ -204,6 +204,9 @@ struct Connection {
   // Parse message from socket recv buffer.
   bool ParseMessage();

+  // After ParseMessage, set from url and to url into recv message.
+  bool SetUrlForRecvMessage();
+
   // Make a http message based on given input message.
   std::string GenerateHttpMessage(MessageBase *msg);
@@ -65,6 +65,26 @@ void PsEmbeddingCacheInserter::GetEmbeddingLookupNodes() {
   });
 }

+void PsEmbeddingCacheInserter::GetCacheEnableParameters() {
+  MS_EXCEPTION_IF_NULL(root_graph_);
+  const std::vector<AnfNodePtr> &parameters = root_graph_->parameters();
+  auto params_size = parameters.size();
+  for (size_t i = 0; i < params_size; ++i) {
+    MS_EXCEPTION_IF_NULL(parameters[i]);
+    if (!parameters[i]->isa<Parameter>()) {
+      MS_LOG(EXCEPTION) << "The node with name: " << parameters[i]->fullname_with_scope() << " is not a Parameter.";
+    }
+
+    ParameterPtr param = parameters[i]->cast<ParameterPtr>();
+    MS_EXCEPTION_IF_NULL(param);
+    auto param_info = param->param_info();
+    if (param_info && param_info->key() != -1 && param_info->cache_enable()) {
+      keys_to_params_[param_info->key()] = param;
+      MS_LOG(INFO) << "Parameter[" << param->fullname_with_scope() << "], key[" << param_info->key() << "]";
+    }
+  }
+}
+
 void PsEmbeddingCacheInserter::SetNodeAttr(const CNodePtr &node, const std::string &node_role) const {
   MS_EXCEPTION_IF_NULL(node);
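GetCacheEnableParameters walks the graph parameters and records only those with a valid key and cache_enable set. A self-contained sketch of that filtering, with stand-in structs instead of MindSpore's Parameter/ParamInfo classes:

```cpp
#include <iostream>
#include <map>
#include <string>
#include <vector>

// Stand-in for ParamInfo: key == -1 means "no parameter key assigned".
struct ParamInfo {
  int key = -1;
  bool cache_enable = false;
  std::string name;
};

int main() {
  std::vector<ParamInfo> params = {{0, true, "embedding_table_0"},
                                   {-1, false, "dense_weight"},
                                   {1, true, "embedding_table_1"}};
  std::map<int, std::string> keys_to_params;
  for (const auto &p : params) {
    // Only parameters with a valid key and cache enabled are recorded.
    if (p.key != -1 && p.cache_enable) {
      keys_to_params[p.key] = p.name;
    }
  }
  for (const auto &[k, n] : keys_to_params) {
    std::cout << k << ": " << n << "\n";  // 0: embedding_table_0, 1: embedding_table_1
  }
  return 0;
}
```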
@@ -107,7 +127,7 @@ void PsEmbeddingCacheInserter::SetSendNodeAttr(const CNodePtr &send_node, int32_
     dst_ranks.push_back(i);
     dst_roles.push_back(dst_role);
     // Unique edge name: src role + src rank id -> dst role + dst rank id + embedding cache operation + parameter key.
-    inter_process_edges.push_back(distributed::kEnvRoleOfServer + std::to_string(rank_id_) + "->" + dst_role +
+    inter_process_edges.push_back(distributed::kEnvRoleOfPServer + std::to_string(rank_id_) + "->" + dst_role +
                                   std::to_string(i) + "_" + embedding_cache_op + "_" + distributed::kParameterKey +
                                   std::to_string(param_key));
   }
@@ -118,6 +138,7 @@ void PsEmbeddingCacheInserter::SetSendNodeAttr(const CNodePtr &send_node, int32_
   common::AnfAlgo::SetNodeAttr(kAttrSendDstNodeName, MakeValue(std::string(kEmbeddingLocalCacheNode)), send_node);

   common::AnfAlgo::SetNodeAttr(kAttrInterProcessEdgeNames, MakeValue(inter_process_edges), send_node);
+  common::AnfAlgo::SetNodeAttr(kAttrIsMuxRpcKernel, MakeValue(true), send_node);
 }

 void PsEmbeddingCacheInserter::SetRecvNodeAttr(const CNodePtr &recv_node, const std::string &src_role) const {
@@ -139,7 +160,7 @@ void PsEmbeddingCacheInserter::SetRecvNodeAttr(const CNodePtr &recv_node, const
     src_roles.push_back(src_role);
     // Unique edge name: src role + src rank id -> dst role + dst rank id + embedding cache operation + parameter
     // key.
-    inter_process_edges.push_back(src_role + std::to_string(i) + "->" + distributed::kEnvRoleOfServer +
+    inter_process_edges.push_back(src_role + std::to_string(i) + "->" + distributed::kEnvRoleOfPServer +
                                   std::to_string(rank_id_) + "_" + distributed::kEmbeddingCacheOps[k] + "_" +
                                   distributed::kParameterKey + std::to_string(param_key));
   }
@@ -150,7 +171,9 @@ void PsEmbeddingCacheInserter::SetRecvNodeAttr(const CNodePtr &recv_node, const
   common::AnfAlgo::SetNodeAttr(kAttrRecvSrcRoles, MakeValue(src_roles), recv_node);
   common::AnfAlgo::SetNodeAttr(kAttrRecvSrcNodeName, MakeValue(std::string(kEmbeddingLocalCacheNode)), recv_node);
   common::AnfAlgo::SetNodeAttr(kAttrRecvDstNodeName, MakeValue(std::string(kEmbeddingRemoteCacheNode)), recv_node);
+
   common::AnfAlgo::SetNodeAttr(kAttrInterProcessEdgeNames, MakeValue(inter_process_edges), recv_node);
+  common::AnfAlgo::SetNodeAttr(kAttrIsMuxRpcKernel, MakeValue(true), recv_node);
 }

 CNodePtr PsEmbeddingCacheInserter::CreateReturnNode(const FuncGraphPtr graph, const AnfNodePtr &output_node) const {
@@ -421,6 +444,9 @@ bool PsEmbeddingCacheInserter::Run() {
   // Get EmbeddingLookup nodes which are executed on server from origin function graph.
   GetEmbeddingLookupNodes();

+  // Get parameters which enable embedding cache from the origin function graph.
+  GetCacheEnableParameters();
+
   // Construct the embedding cache graph of server.
   RETURN_IF_FALSE_WITH_LOG(ConstructEmbeddingCacheGraph(), "Construct embedding cache graph failed.");
@@ -21,7 +21,6 @@
 #include <map>
 #include <vector>

-#include "utils/hash_map.h"
 #include "ir/anf.h"
 #include "distributed/constants.h"
@@ -104,7 +103,7 @@ class PsEmbeddingCacheInserter {

   // Record parameters enabled embedding cache of origin function graph.
   // Key: parameter key, Value: ParameterPtr
-  mindspore::HashMap<int32_t, ParameterPtr> keys_to_params_;
+  std::map<int32_t, ParameterPtr> keys_to_params_;

   // Record EmbeddingLookup nodes which are executed on server from origin function graph.
   // Key: shape of EmbeddingLookup node, Value: EmbeddingLookup AnfNodePtr.
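Switching keys_to_params_ from mindspore::HashMap to std::map makes iteration over parameter keys deterministic (sorted by key), which is one plausible motivation for the change. A minimal illustration:

```cpp
#include <iostream>
#include <map>
#include <string>

int main() {
  // std::map orders entries by key regardless of insertion order;
  // a hash map gives no such guarantee across runs or platforms.
  std::map<int, std::string> keys_to_params = {{2, "table_b"}, {0, "table_a"}, {1, "table_c"}};
  for (const auto &[key, name] : keys_to_params) {
    std::cout << key << " -> " << name << "\n";  // always prints keys 0, 1, 2
  }
  return 0;
}
```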
@@ -580,6 +580,7 @@ constexpr auto kAttrRecvSrcRanks = "recv_src_ranks";
 constexpr auto kAttrRecvSrcRoles = "recv_src_roles";
 constexpr auto kAttrInterProcessEdgeNames = "inter_process_edge_names";
 constexpr auto kAttrInterProcessEdgeLabel = "inter_process_edge_label";
+constexpr auto kAttrIsMuxRpcKernel = "is_mux_rpc_kernel";
 constexpr auto kAttrForwardOpOutputId = "forward_op_output_id";
 constexpr auto kAttrGroupRankIds = "group_rank_ids";
 constexpr auto kAttrReuseCommunication = "reuse_communication_node";
@@ -1423,6 +1423,12 @@ bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batc
   MS_EXCEPTION_IF_NULL(backend);
   // The data set graph compiling and running of mindRT.
   if (MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_MINDRT)) {
+#ifdef WITH_BACKEND
+    if (ps::PSContext::instance()->is_worker() && ps::PSContext::instance()->cache_enable()) {
+      ps::PsDataPrefetch::GetInstance().CreateDataChannel(queue_name, LongToSize(size));
+    }
+#endif
+
     const auto &mindrt_backend = std::dynamic_pointer_cast<compile::MindRTBackend>(backend);
     MS_EXCEPTION_IF_NULL(mindrt_backend);
     auto &actor_info = mindrt_backend->CompileGraphs(func_graph);
@@ -361,6 +361,61 @@ void SetKernelInfoBeforeCreateKernel(const std::vector<CNodePtr> &nodes) {
     }
   }
 }
+
+// Check whether a mutex exists for a stream.
+std::pair<bool, std::mutex *> CheckStreamMutexExist(
+  const void *stream, const mindspore::HashMap<const void *, std::shared_ptr<std::mutex>> &mtxs_for_streams,
+  std::shared_mutex *shd_mtx) {
+  MS_EXCEPTION_IF_NULL(stream);
+  MS_EXCEPTION_IF_NULL(shd_mtx);
+  std::shared_lock<std::shared_mutex> shd_lock(*shd_mtx);
+  auto iter = mtxs_for_streams.find(stream);
+  if (iter != mtxs_for_streams.end()) {
+    MS_EXCEPTION_IF_NULL(iter->second);
+    return std::make_pair(true, iter->second.get());
+  }
+  return std::make_pair(false, nullptr);
+}
+
+// Create a mutex for a stream.
+std::mutex *CreateStreamMutex(const void *stream, std::shared_mutex *shd_mtx,
+                              mindspore::HashMap<const void *, std::shared_ptr<std::mutex>> *mtxs_for_streams) {
+  MS_EXCEPTION_IF_NULL(stream);
+  MS_EXCEPTION_IF_NULL(shd_mtx);
+  MS_EXCEPTION_IF_NULL(mtxs_for_streams);
+
+  std::unique_lock<std::shared_mutex> unq_lock(*shd_mtx);
+  auto ret_pair = mtxs_for_streams->emplace(stream, std::make_shared<std::mutex>());
+
+  MS_EXCEPTION_IF_NULL(ret_pair.first->second);
+  return ret_pair.first->second.get();
+}
+
+// Kernel launch is not thread-safe: launches delivered to the same stream need lock protection, so a separate
+// lock is created for each stream.
+// On GPU, the cuBLAS handle specifically is not thread-safe; multiple threads should not access the same cuBLAS
+// handle at the same time, so the launch mutex is also needed when multiple threads launch cuBLAS kernels.
+std::lock_guard<std::mutex> LockLaunchKernel(const void *stream) {
+  MS_EXCEPTION_IF_NULL(stream);
+  // Read-write lock for accessing the mtxs_for_streams map.
+  // Once the lock of each stream is created, mtxs_for_streams can be accessed concurrently to improve performance.
+  static std::shared_mutex shd_mtx;
+  static mindspore::HashMap<const void *, std::shared_ptr<std::mutex>> mtxs_for_streams;
+
+  std::mutex *stream_mtx;
+  // Check whether a mutex exists for the stream.
+  std::pair<bool, std::mutex *> ret_pair = CheckStreamMutexExist(stream, mtxs_for_streams, &shd_mtx);
+  if (ret_pair.first) {
+    stream_mtx = ret_pair.second;
+  } else {
+    // Create a mutex for the stream.
+    stream_mtx = CreateStreamMutex(stream, &shd_mtx, &mtxs_for_streams);
+  }
+
+  MS_EXCEPTION_IF_NULL(stream_mtx);
+  // Lock kernel launch for the stream.
+  return std::lock_guard<std::mutex>(*stream_mtx);
+}
 } // namespace

 void GPUDeviceContext::OptimizeGraph(const FuncGraphPtr &graph) const {
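A condensed, standalone sketch of the per-stream locking scheme above, using std::unordered_map in place of mindspore::HashMap: lookups of an existing mutex take a shared (reader) lock on every launch, and only the first launch on a stream takes the exclusive (writer) lock to insert:

```cpp
#include <memory>
#include <mutex>
#include <shared_mutex>
#include <unordered_map>

std::mutex *GetStreamMutex(const void *stream) {
  static std::shared_mutex map_mtx;
  static std::unordered_map<const void *, std::unique_ptr<std::mutex>> mtxs;
  {
    // Fast path: shared lock, so concurrent launches on known streams don't serialize here.
    std::shared_lock<std::shared_mutex> read_lock(map_mtx);
    auto it = mtxs.find(stream);
    if (it != mtxs.end()) {
      return it->second.get();
    }
  }
  // Slow path: exclusive lock, taken once per stream.
  std::unique_lock<std::shared_mutex> write_lock(map_mtx);
  // If another thread inserted the mutex between the two locks, emplace
  // leaves the existing entry in place and returns it.
  auto [it, inserted] = mtxs.emplace(stream, std::make_unique<std::mutex>());
  return it->second.get();
}
```

Returning the std::lock_guard by value, as LockLaunchKernel does, relies on C++17 guaranteed copy elision, since lock_guard is neither copyable nor movable.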
@@ -462,20 +517,23 @@ bool GPUDeviceContext::LaunchKernel(const CNodePtr &kernel, const std::vector<Ad
   }
   bool ret = true;

+  auto stream = GetLaunchKernelStream(kernel);
+  MS_EXCEPTION_IF_NULL(stream);
+
 #ifndef ENABLE_SECURITY
   const auto &profiler_inst = profiler::gpu::GPUProfiler::GetInstance();
   MS_EXCEPTION_IF_NULL(profiler_inst);

   if (!profiler_inst->GetEnableFlag()) {
 #endif
-    std::lock_guard<std::mutex> locker(launch_mutex_);
+    auto lock = LockLaunchKernel(stream);
     MS_LOG(DEBUG) << "Launch kernel: " << kernel->fullname_with_scope();
-    ret = DoLaunchKernel(kernel, inputs, workspace, outputs);
+    ret = DoLaunchKernel(kernel, inputs, workspace, outputs, stream);
 #ifndef ENABLE_SECURITY
   } else {
-    std::lock_guard<std::mutex> locker(launch_mutex_);
+    auto lock = LockLaunchKernel(stream);
     MS_LOG(DEBUG) << "Launch kernel: " << kernel->fullname_with_scope();
-    ret = LaunchKernelWithProfiling(kernel, inputs, workspace, outputs);
+    ret = LaunchKernelWithProfiling(kernel, inputs, workspace, outputs, stream);
   }
 #endif
   if (!ret) {
@@ -496,8 +554,9 @@ bool GPUDeviceContext::LaunchKernel(const CNodePtr &kernel, const std::vector<Ad
 #ifndef ENABLE_SECURITY
 bool GPUDeviceContext::LaunchKernelWithProfiling(const CNodePtr &kernel, const std::vector<AddressPtr> &inputs,
-                                                 const std::vector<AddressPtr> &workspace,
-                                                 const std::vector<AddressPtr> &outputs) const {
+                                                 const std::vector<AddressPtr> &workspace,
+                                                 const std::vector<AddressPtr> &outputs, void *stream) const {
   MS_EXCEPTION_IF_NULL(kernel);
+  MS_EXCEPTION_IF_NULL(stream);

   auto kernel_graph = std::dynamic_pointer_cast<KernelGraph>(kernel->func_graph());
   MS_EXCEPTION_IF_NULL(kernel_graph);
@@ -512,7 +571,7 @@ bool GPUDeviceContext::LaunchKernelWithProfiling(const CNodePtr &kernel, const s
   }

   profiler_inst->OpDataProducerBegin(kernel->fullname_with_scope(), streams_.front());
-  bool ret = DoLaunchKernel(kernel, inputs, workspace, outputs);
+  bool ret = DoLaunchKernel(kernel, inputs, workspace, outputs, stream);
   profiler_inst->OpDataProducerEnd();
   profiler_inst->RecordFrameWorkInfo(kernel);
@@ -527,8 +586,16 @@ bool GPUDeviceContext::LaunchKernelWithProfiling(const CNodePtr &kernel, const s
 }
 #endif
 bool GPUDeviceContext::DoLaunchKernel(const CNodePtr &kernel, const std::vector<AddressPtr> &inputs,
-                                      const std::vector<AddressPtr> &workspace,
-                                      const std::vector<AddressPtr> &outputs) const {
+                                      const std::vector<AddressPtr> &workspace, const std::vector<AddressPtr> &outputs,
+                                      void *stream) const {
   MS_EXCEPTION_IF_NULL(kernel);
+  MS_EXCEPTION_IF_NULL(stream);
+  auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
+  MS_EXCEPTION_IF_NULL(kernel_mod);
+  return kernel_mod->Launch(inputs, workspace, outputs, stream);
+}
+
+void *GPUDeviceContext::GetLaunchKernelStream(const CNodePtr &kernel) const {
+  void *stream = nullptr;
   if (common::AnfAlgo::HasNodeAttr(kAttrStream, kernel)) {
     auto stream_id = common::AnfAlgo::GetNodeAttr<size_t>(kernel, kAttrStream);
@@ -542,9 +609,7 @@ bool GPUDeviceContext::DoLaunchKernel(const CNodePtr &kernel, const std::vector<
   }

   MS_EXCEPTION_IF_NULL(stream);
-  auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
-  MS_EXCEPTION_IF_NULL(kernel_mod);
-  return kernel_mod->Launch(inputs, workspace, outputs, stream);
+  return stream;
 }

 bool GPUDeviceContext::SyncStream(size_t stream_id) const {
@@ -92,12 +92,17 @@ class GPUDeviceContext : public DeviceContext {
 #ifndef ENABLE_SECURITY
   // Launch a kernel and record the elapsed time end to end.
   bool LaunchKernelWithProfiling(const CNodePtr &kernel, const std::vector<AddressPtr> &inputs,
-                                 const std::vector<AddressPtr> &workspace,
-                                 const std::vector<AddressPtr> &outputs) const;
+                                 const std::vector<AddressPtr> &workspace, const std::vector<AddressPtr> &outputs,
+                                 void *stream) const;
 #endif
   // Launch a kernel by 'KernelMod' of the kernel.
   bool DoLaunchKernel(const CNodePtr &kernel, const std::vector<AddressPtr> &inputs,
-                      const std::vector<AddressPtr> &workspace, const std::vector<AddressPtr> &outputs) const;
+                      const std::vector<AddressPtr> &workspace, const std::vector<AddressPtr> &outputs,
+                      void *stream) const;
+
+  // Get the stream used to launch the kernel: if a stream is saved in the kernel's attrs, use that stream,
+  // otherwise use the default stream.
+  void *GetLaunchKernelStream(const CNodePtr &kernel) const;

   // Really create a cuda stream.
   bool CreateStream(void **stream) const override;
@@ -105,10 +110,6 @@ class GPUDeviceContext : public DeviceContext {
   // Really destroy a cuda stream.
   bool DestroyStream(void *stream) const override;

-  // The cublas handle is not thread safety specifically, it is not recommended that multiple threads access the same
-  // cublas handle at the same time, so need the launch mutex when multiple threads launch the cublas kernels.
-  mutable std::mutex launch_mutex_;
-
   std::shared_ptr<MemoryManager> mem_manager_;
   std::vector<void *> streams_;
   bool initialized_;
@@ -22,6 +22,11 @@ namespace ps {
 const size_t kTimeoutLoopCount = 40;
 const int64_t kLongestTimeToWait = 30;

+PsDataPrefetch &PsDataPrefetch::GetInstance() {
+  static PsDataPrefetch instance;
+  return instance;
+}
+
 void PsDataPrefetch::CreateDataChannel(const std::string &channel_name, size_t step_num) {
   if (cache_enable_ == false) {
     return;
@@ -29,10 +29,7 @@ namespace mindspore {
 namespace ps {
 class EXPORT PsDataPrefetch {
  public:
-  EXPORT static PsDataPrefetch &GetInstance() {
-    static PsDataPrefetch instance;
-    return instance;
-  }
+  EXPORT static PsDataPrefetch &GetInstance();

   EXPORT bool cache_enable() const { return cache_enable_; }
   EXPORT void set_cache_enable(bool cache_enable) { cache_enable_ = cache_enable; }
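Moving GetInstance() from an inline header definition into the .cc file pins the function-local static to a single translation unit; for an exported class this avoids the risk of each consumer (or each shared object) instantiating its own copy of the singleton. A sketch of the pattern with a hypothetical Prefetch class, both halves shown in one listing:

```cpp
// --- prefetch.h: declaration only, no inline body ---
class Prefetch {
 public:
  static Prefetch &GetInstance();  // defined exactly once, in prefetch.cc

 private:
  Prefetch() = default;  // construction controlled by GetInstance
};

// --- prefetch.cc: the single definition ---
Prefetch &Prefetch::GetInstance() {
  static Prefetch instance;  // one definition, one instance process-wide
  return instance;
}
```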
@@ -19,6 +19,7 @@
 #include "runtime/graph_scheduler/device_tensor_store.h"
 #include "utils/ms_context.h"
 #include "include/common/utils/anfalgo.h"
+#include "ps/ps_context.h"

 namespace mindspore {
 namespace runtime {
@@ -391,5 +392,9 @@ std::set<size_t> FetchModifiableRefOutputIndex(const CNodePtr &cnode, const Kern
   }
   return ref_output_indexes;
 }
+
+bool is_embedding_cache_server() {
+  return ps::PSContext::instance()->cache_enable() && ps::PSContext::instance()->is_server();
+}
 } // namespace runtime
 } // namespace mindspore
@@ -285,6 +285,9 @@ std::string FetchActorName(KernelTransformType kernel_type, const std::string &a
 std::set<size_t> FetchModifiableRefInputIndex(const CNodePtr &node);
 // Fetch the output indexes which may be modified that exist in the ref node.
 std::set<size_t> FetchModifiableRefOutputIndex(const CNodePtr &node, const KernelGraphPtr &graph);
+
+// Check whether this process is a parameter server with embedding cache enabled.
+bool is_embedding_cache_server();
 } // namespace runtime
 } // namespace mindspore
@@ -22,6 +22,7 @@

 namespace mindspore {
 namespace runtime {
+using distributed::cluster::ClusterContext;
 using mindspore::session::KernelGraph;

 // One and two dimensional shape placeholder.
@@ -58,7 +59,7 @@ bool InferOpShape(const CNodePtr &kernel) {

   auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
   MS_EXCEPTION_IF_NULL(kernel_mod);
-  if (!kernel_mod->Resize(args->op, args->inputs, args->outputs, args->depend_tensor_map)) {
+  if (kernel::KRET_OK != kernel_mod->Resize(args->op, args->inputs, args->outputs, args->depend_tensor_map)) {
     MS_LOG(ERROR) << "Kernel " << kernel->fullname_with_scope() << " resize failed.";
     return false;
   }
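The Resize() check changed because a bool test inverts enum return codes whose success value is zero (assuming KRET_OK == 0, which the original `if (!...)` form implies). A stand-in demonstration of why the explicit comparison is needed:

```cpp
#include <cassert>

// Stand-in for MindSpore's kernel error codes; success is the zero value.
enum KernelErrorCode { KRET_OK = 0, KRET_RESIZE_FAILED = 1 };

KernelErrorCode Resize() { return KRET_OK; }  // pretend resize succeeded

int main() {
  // Buggy form: KRET_OK converts to 0, so !Resize() is true on *success*,
  // which would take the error branch and log "resize failed".
  assert(!Resize());
  // Fixed form: compare against the success enumerator explicitly.
  assert(KRET_OK == Resize());
  return 0;
}
```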
@@ -88,14 +89,17 @@ SendRecvPair CreateSenderReceiverPair(uint32_t worker_rank, uint32_t server_rank
                                       int32_t param_key) {
   // Create sender and receiver pair.
   ReceiverPtr receiver = std::make_shared<Receiver>();
-  SenderPtr sender = std::make_shared<Sender>(receiver);
+  MS_EXCEPTION_IF_NULL(receiver);
+  SenderPtr sender = std::make_shared<Sender>();
+  MS_EXCEPTION_IF_NULL(sender);
+  sender->set_receiver(receiver);

   // Set inter process edge
-  receiver->set_inter_process_edge_name(GenerateInterProcessEdge(distributed::kEnvRoleOfServer, server_rank,
+  receiver->set_inter_process_edge_name(GenerateInterProcessEdge(distributed::kEnvRoleOfPServer, server_rank,
                                                                  distributed::kEnvRoleOfWorker, worker_rank,
                                                                  cache_operation, param_key));
   sender->set_inter_process_edge_name(GenerateInterProcessEdge(distributed::kEnvRoleOfWorker, worker_rank,
-                                                               distributed::kEnvRoleOfServer, server_rank,
+                                                               distributed::kEnvRoleOfPServer, server_rank,
                                                                distributed::kLookupEmbeddingCache, param_key));

   // Set route table proxy.
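The Sender is now default-constructed and wired to its Receiver through set_receiver() instead of a constructor argument, so each shared_ptr can be null-checked before the pair is used. A stripped-down sketch of that two-phase wiring:

```cpp
#include <cassert>
#include <memory>

struct Receiver {};

class Sender {
 public:
  Sender() = default;
  // Wire up the paired receiver after both objects exist and are checked.
  void set_receiver(const std::shared_ptr<Receiver> &r) { receiver_ = r; }

 private:
  std::shared_ptr<Receiver> receiver_;
};

int main() {
  auto receiver = std::make_shared<Receiver>();
  assert(receiver != nullptr);  // checked before being handed to the sender
  auto sender = std::make_shared<Sender>();
  assert(sender != nullptr);
  sender->set_receiver(receiver);  // two-phase construction
  return 0;
}
```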
@@ -172,9 +176,24 @@ void EmbeddingCachePrefetchActor::Initialize() {
     MS_LOG(EXCEPTION) << "Create stream failed.";
   }

+  // Get embedding cache table info.
+  hash_tables_ = embedding_cache_table_manager.hash_tables_;
+  local_host_cache_size_ = embedding_cache_table_manager.host_cache_size_;
+  vocab_size_ = embedding_cache_table_manager.vocab_size_;
+  embedding_device_cache_ = embedding_cache_table_manager.embedding_device_cache_;
+  MS_EXCEPTION_IF_NULL(embedding_device_cache_);
+  embedding_host_cache_ = embedding_cache_table_manager.embedding_host_cache_;
+  MS_EXCEPTION_IF_NULL(embedding_host_cache_);
+  local_embedding_slice_bounds_ = embedding_cache_table_manager.local_embedding_slice_bounds_;
+  local_device_cache_bounds_ = embedding_cache_table_manager.local_device_cache_bounds_;
+
+  // Get the id range of each server's embedding table slice.
+  GetRemoteEmbeddingSliceBound();
+
   BuildEmbeddingCacheLookupKernel();
   BuildEmbeddingCacheUpdateKernel();

   // Build and link rpc operators.
   BuildRpcOperators();
   LinkRpcOperators();
 }
@@ -326,8 +345,40 @@ bool EmbeddingCachePrefetchActor::UpdateDeviceCache(void *indices, void *update_
   return true;
 }

+void EmbeddingCachePrefetchActor::IncreaseGraphStep(const std::string &channel_name) {
+  if (!running_) {
+    MS_LOG(EXCEPTION) << "PS embedding cache data processing thread isn't running.";
+  }
+  if (graph_step_ >= UINT64_MAX) {
+    MS_LOG(EXCEPTION) << "The graph step(" << graph_step_ << ") will exceed the maximum value of uint64_t.";
+  }
+  if (graph_step_ == 0) {
+    MS_LOG(INFO) << "Graph running waiting embedding table init begin:" << finish_init_parameters_on_remote_;
+    std::unique_lock<std::mutex> locker(data_mutex_);
+    data_parser_.wait(locker, [this] { return ((finish_init_parameters_on_remote_ == true) || (running_ == false)); });
+    if (!running_) {
+      MS_LOG(EXCEPTION) << "PS embedding cache data processing thread isn't running.";
+    }
+    MS_LOG(INFO) << "Graph running waiting embedding table init end.";
+  }
+  graph_step_++;
+  set_channel_name(channel_name);
+  if (!PsDataPrefetch::GetInstance().TryWakeChannel(channel_name)) {
+    MS_LOG(EXCEPTION) << "TryWakeChannel failed, channel name: " << channel_name;
+  }
+  data_parser_.notify_one();
+}
+
 void EmbeddingCachePrefetchActor::Run() {
-  // Note: Need to wait data channel ready.
   running_ = true;
+
+  // Wait for parameters to be initialized on the remote side.
+  // This prevents the subsequent prefetch cache from failing due to the long initialization time of large
+  // parameters on the remote side.
+  WaitInitParametersOnRemote();
+
+  // Wait for the data channel to be ready.
+  WaitDataChannelInit();
+
   MS_LOG(INFO) << "Begin prefetching cache.";
   while (running_) {
|
|||
}
|
||||
|
||||
MS_ERROR_IF_NULL(embedding_device_cache_);
|
||||
MS_ERROR_IF_NULL(embedding_device_cache_->cache_);
|
||||
MS_ERROR_IF_NULL(embedding_host_cache_);
|
||||
|
||||
auto device_cache_device_to_host_index = embedding_device_cache_->device_to_host_index.get();
|
||||
|
@@ -798,7 +848,6 @@ bool EmbeddingCachePrefetchActor::PullCacheFromLocalHostToDevice(const HashTable
   }

   MS_ERROR_IF_NULL(embedding_device_cache_);
   MS_ERROR_IF_NULL(embedding_device_cache_->cache_);
   MS_ERROR_IF_NULL(embedding_host_cache_);

   auto host_cache_host_to_device_index = embedding_host_cache_->host_to_device_index.get();
@@ -1126,7 +1175,6 @@ bool EmbeddingCachePrefetchActor::SendToRemote(const std::string &cache_operatio
                                                size_t server_rank_id, size_t embedding_dim, const void *keys,
                                                size_t keys_len, const void *values, size_t values_len) {
   MS_ERROR_IF_NULL(keys);
-  MS_ERROR_IF_NULL(values);
   // Find sender corresponding to cache operation and parameter key.
   auto iter = rpc_operators_.find(cache_operation);
   if (iter == rpc_operators_.end()) {
@@ -1138,8 +1186,26 @@ bool EmbeddingCachePrefetchActor::SendToRemote(const std::string &cache_operatio
   MS_ERROR_IF_NULL(sender);

   int64_t ids_num = SizeToLong(keys_len / sizeof(int));
-  std::vector<ShapeVector> shapes = {{ids_num}, {ids_num, SizeToLong(embedding_dim)}, {1}};
-  std::vector<TypeId> data_types = {kNumberTypeInt32, kNumberTypeFloat32, kNumberTypeInt64};
+  ShapeVector ids_shape = {ids_num};
+  ShapeVector values_shape;
+  float fake_value = 0.0;
+
+  if (values == nullptr && values_len == 0) {
+    values_shape = {1, 1};
+    values = &fake_value;
+    values_len = sizeof(fake_value);
+  } else {
+    MS_EXCEPTION_IF_ZERO("embedding_dim", embedding_dim);
+    int64_t embed_vec_num = SizeToLong(values_len / sizeof(float) / embedding_dim);
+    if (embed_vec_num != ids_num) {
+      MS_LOG(EXCEPTION) << "The embedding vector number[" << embed_vec_num << "] should be equal to the ids number["
+                        << ids_num << "] which will be sent to remote.";
+    }
+    values_shape = {embed_vec_num, SizeToLong(embedding_dim)};
+  }
+
+  std::vector<ShapeVector> shapes = {ids_shape, values_shape, {static_cast<int64_t>(1)}};
+  std::vector<TypeId> data_types = {kNumberTypeInt32, kNumberTypeFloat32, kNumberTypeInt32};

   int32_t service_id = GetCacheOpsServiceId(cache_operation, param_key);
   AddressPtrList data_list = {std::make_shared<Address>(const_cast<void *>(keys), keys_len),
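When a request carries no values payload (e.g. a pure lookup), SendToRemote now substitutes a 1x1 placeholder so the message still ships a well-formed tensor, and the receiver can tell the cases apart by shape. A sketch of that fallback with stand-in types:

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Stand-in for an outgoing tensor chunk: raw bytes plus a shape.
struct Payload {
  const void *data;
  size_t len;
  std::vector<int64_t> shape;
};

Payload MakeValuesPayload(const void *values, size_t values_len, int64_t vec_num, int64_t dim) {
  // Static so the placeholder's address stays valid after return.
  static const float kFakeValue = 0.0f;
  if (values == nullptr && values_len == 0) {
    // Nothing to send: substitute a 1x1 placeholder the receiver can recognize.
    return {&kFakeValue, sizeof(kFakeValue), {1, 1}};
  }
  return {values, values_len, {vec_num, dim}};
}
```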
@@ -1220,6 +1286,39 @@ bool EmbeddingCachePrefetchActor::RetrieveEmbeddings(
   return true;
 }

+std::string EmbeddingCachePrefetchActor::channel_name() {
+  std::lock_guard<std::mutex> locker(channel_mutex_);
+  return channel_name_;
+}
+
+void EmbeddingCachePrefetchActor::set_channel_name(const std::string channel_name) {
+  std::lock_guard<std::mutex> locker(channel_mutex_);
+  if (channel_name_ == channel_name) {
+    return;
+  }
+  channel_name_ = channel_name;
+}
+
+void EmbeddingCachePrefetchActor::WaitDataChannelInit() {
+  MS_LOG(INFO) << "Begin wait embedding cache data channel init.";
+  auto channel = channel_name();
+  if (channel.empty()) {
+    std::unique_lock<std::mutex> locker(data_mutex_);
+    data_parser_.wait(locker, [this] { return !channel_name_.empty() || running_ == false; });
+    if (!running_) {
+      return;
+    }
+  }
+  MS_LOG(INFO) << "End wait embedding cache data channel init.";
+}
+
+void EmbeddingCachePrefetchActor::WaitInitParametersOnRemote() {
+  std::unique_lock<std::mutex> locker(data_mutex_);
+  // Note: wait to finish embedding lookup from remote.
+  finish_init_parameters_on_remote_ = true;
+  data_parser_.notify_one();
+}
+
 void EmbeddingCachePrefetchActor::BuildRpcOperators() {
   // The cache operations supported currently are LookupEmbeddingCache and UpdateEmbeddingCache.
   for (const auto &cache_op : distributed::kEmbeddingCacheOps) {
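channel_name()/set_channel_name() guard the string with a single mutex so a reader never observes a half-written std::string. A minimal version of the guarded accessor pair:

```cpp
#include <mutex>
#include <string>

class ChannelHolder {
 public:
  std::string channel_name() {
    std::lock_guard<std::mutex> lock(mtx_);
    return channel_name_;  // copy is made while the lock is held
  }
  void set_channel_name(const std::string &name) {
    std::lock_guard<std::mutex> lock(mtx_);
    channel_name_ = name;
  }

 private:
  std::mutex mtx_;
  std::string channel_name_;
};
```

Note that the comparison in set_channel_name must also happen under the lock; comparing before locking would read channel_name_ concurrently with a writer.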
@@ -30,19 +30,12 @@
 #include "distributed/rpc/tcp/tcp_client.h"
 #include "distributed/rpc/tcp/tcp_server.h"
 #include "utils/hash_map.h"
+#include "distributed/embedding_cache/embedding_cache_utils.h"

 // Note: After the code in ps/ps_cache is moved into runtime/addons/embedding_cache/,
 // the following include files and using declarations of ps will be removed.
 #include "ps/ps_cache/ps_data/ps_data_prefetch.h"
-#include "ps/ps_cache/embedding_hash_map.h"
-#include "ps/ps_cache/ps_cache_manager.h"
 #include "ps/ps_context.h"
-using mindspore::ps::EmbeddingDeviceCache;
-using mindspore::ps::EmbeddingHostCache;
-using mindspore::ps::HashTableInfo;
-using mindspore::ps::INVALID_INDEX_VALUE;
-using mindspore::ps::INVALID_STEP_VALUE;
-using mindspore::ps::PsCacheStatisticsInfo;
 using mindspore::ps::PSContext;
 using mindspore::ps::PsDataPrefetch;
@@ -59,8 +52,17 @@ using ReceiverPtr = std::shared_ptr<Receiver>;
 using SendRecvPair = std::pair<SenderPtr, ReceiverPtr>;
 using SendRecvPairList = std::vector<SendRecvPair>;

+using distributed::EmbeddingCacheStatisticsInfo;
+using distributed::EmbeddingDeviceCache;
+using distributed::EmbeddingHostCache;
+using distributed::HashTableInfo;
+using distributed::INVALID_INDEX_VALUE;
+using distributed::INVALID_STEP_VALUE;
+
 using distributed::cluster::ActorRouteTableProxy;
 using distributed::cluster::ActorRouteTableProxyPtr;
 using distributed::rpc::TCPClient;
 using distributed::rpc::TCPServer;

 // The EmbeddingCachePrefetchActor is used to cache large embedding table scenarios. The cache level is: Device
 // Cache->Local Host Cache->Remote Cache. This Actor is used to perform Local and Device Cache hit analysis and cache
@@ -82,6 +84,9 @@ class EmbeddingCachePrefetchActor : public ActorBase {
   // Perform local cache hit analysis, prefetch the feature vector corresponding to the next batch into the cache.
   void Run();

+  // Increase the global step of the compute graph.
+  void IncreaseGraphStep(const std::string &channel_name);
+
   // Finalize embedding cache prefetch actor and push latest embedding from local cache to remote cache.
   void Finalize();
@@ -206,6 +211,19 @@ class EmbeddingCachePrefetchActor : public ActorBase {
   bool UpdateDeviceCache(void *indices, void *update_value, size_t indices_num, size_t cache_size,
                          size_t embedding_size, void *embedding_cache);

+  // Get the dataset channel name.
+  std::string channel_name();
+  // Set the dataset channel name.
+  void set_channel_name(const std::string channel_name);
+
+  // Wait for the data channel to be ready.
+  void WaitDataChannelInit();
+
+  // Wait for parameters to be initialized on the remote side.
+  // This prevents the subsequent prefetch cache from failing due to the long initialization time of large
+  // parameters on the remote side.
+  void WaitInitParametersOnRemote();
+
   // Record sender and receiver pairs for different cache operation, server and parameter key.
   // key: cache operation(such as LookupEmbeddingCache and UpdateEmbeddingCache)
   // value: sender and receiver pairs for this kind of cache operation.
@@ -239,7 +257,7 @@ class EmbeddingCachePrefetchActor : public ActorBase {
   std::shared_ptr<EmbeddingHostCache> embedding_host_cache_;

   // Statistics on the cache hit rate of the host and device and the information used to update cache.
-  PsCacheStatisticsInfo statistics_info_;
+  EmbeddingCacheStatisticsInfo statistics_info_;

   // Model parallelism is used between multiple workers, and local_embedding_slice_bounds_ records the feature range
   // corresponding to the embedding table slice of the process.

@@ -255,7 +273,7 @@ class EmbeddingCachePrefetchActor : public ActorBase {
   std::vector<std::pair<size_t, size_t>> remote_embedding_slice_bounds_;

   // Total server number of cluster.
-  size_t server_num_;
+  size_t server_num_{0};

   // The flag which indicates whether this actor is running to prefetch cache.
   std::atomic_bool running_{false};
@@ -269,6 +287,11 @@ class EmbeddingCachePrefetchActor : public ActorBase {

+  // Dataset channel name, used in dataset switching scenarios.
+  std::string channel_name_;
+  // The mutex to access channel_name_.
+  std::mutex channel_mutex_;
+
   // The flag indicates whether initialization of parameters on remote has finished.
   std::atomic_bool finish_init_parameters_on_remote_{false};

   // Data parser condition variable for prefetching cache, used to start and synchronize intermediate state for cache
   // prefetching.
@@ -311,13 +334,16 @@ class RpcOperator {
 // Sender is used to send data to other process.
 class Sender : public RpcOperator {
  public:
-  explicit Sender(const ReceiverPtr &receiver) : server_url_(""), client_(nullptr), receiver_(receiver) {}
+  Sender() : server_url_(""), client_(nullptr) {}
   ~Sender();

   // Send buffer to peer.
   bool Send(const std::vector<ShapeVector> &shapes, const std::vector<TypeId> data_types,
             const AddressPtrList &data_list) const;

+  // Set the receiver paired with the sender to get the 'from url' from the receiver.
+  void set_receiver(const ReceiverPtr &receiver) { receiver_ = receiver; }
+
   // Lookup peer receiver's route and build network connection.
   bool ConnectServer();
@@ -15,10 +15,10 @@
  */

 #include "runtime/graph_scheduler/embedding_cache_scheduler.h"
-#include "runtime/graph_scheduler/actor/embedding_cache/embedding_cache_prefetch_actor.h"

 #include <string>
 #include <memory>
+#include "runtime/graph_scheduler/actor/embedding_cache/embedding_cache_prefetch_actor.h"
 #include "utils/ms_context.h"

 namespace mindspore {
 namespace runtime {