diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/push_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/push_kernel.h
index cffba873fb..800315e5f3 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/push_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/push_kernel.h
@@ -43,7 +43,10 @@ class PushKernel : public CPUKernel {
       sizes.push_back(SizeToInt(input->size) / sizeof(T));
     }
     parallel::ps::Worker<T>::GetInstance().Push(keys, addrs, sizes);
-    memcpy_s(outputs[0]->addr, sizeof(size_t), &key_, sizeof(size_t));
+    auto ret = memcpy_s(outputs[0]->addr, sizeof(size_t), &key_, sizeof(size_t));
+    if (ret != EOK) {
+      MS_LOG(EXCEPTION) << "Lookup id memcpy failed.";
+    }
     return true;
   }
diff --git a/mindspore/ccsrc/frontend/parallel/ps/common.h b/mindspore/ccsrc/frontend/parallel/ps/common.h
index a072a65001..b0d557dc1f 100644
--- a/mindspore/ccsrc/frontend/parallel/ps/common.h
+++ b/mindspore/ccsrc/frontend/parallel/ps/common.h
@@ -66,6 +66,8 @@
 constexpr int kInitWeightToOptimIdCmd = 11;
 constexpr int kInitOptimInputsShapeCmd = 12;
 constexpr int kInitKeyToPushNodeIdCmd = 13;
 constexpr int kInitEmbeddingsCmd = 20;
+constexpr int kCheckReadyForPushCmd = 25;
+constexpr int kCheckReadyForPullCmd = 26;
 constexpr int kEmbeddingLookupCmd = 30;
 constexpr int kFinalizeCmd = 40;
diff --git a/mindspore/ccsrc/frontend/parallel/ps/optimizer_info_builder.cc b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info_builder.cc
index b87b001696..23ad87c41e 100644
--- a/mindspore/ccsrc/frontend/parallel/ps/optimizer_info_builder.cc
+++ b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info_builder.cc
@@ -158,16 +158,19 @@ OptimizerInfo *SparseFtrlOptimInfoBuilder::BuildInputs(const WeightPtr &weight,
   }
   AddressPtr linear = std::make_shared<kernel::Address>();
   linear->addr = new float[weight->size()];
-  memcpy_s(linear->addr, weight->size() * sizeof(float), 0x00, weight->size() * sizeof(float));
+  auto ret = memset_s(linear->addr, weight->size() * sizeof(float), 0x00, weight->size() * sizeof(float));
+  if (ret != 0) {
+    MS_LOG(EXCEPTION) << "memset_s error, errorno(" << ret << ")";
+  }
   linear->size = weight->size() * sizeof(float);

   const std::shared_ptr<std::vector<size_t>> &grad_shape = (*inputs_shape)[3];
   size_t total_grad_size = std::accumulate((*grad_shape).begin(), (*grad_shape).end(), 1, std::multiplies<size_t>());
   AddressPtr grad = std::make_shared<kernel::Address>();
   grad->addr = new float[total_grad_size * worker_num];
-  auto ret = memcpy_s(grad->addr, lens[0] * sizeof(float), values.data(), lens[0] * sizeof(float));
-  if (ret != 0) {
-    MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
+  auto ret1 = memcpy_s(grad->addr, lens[0] * sizeof(float), values.data(), lens[0] * sizeof(float));
+  if (ret1 != 0) {
+    MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret1 << ")";
   }
   grad->size = lens[0] * sizeof(float);
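Two fixes above share one shape: push_kernel.h checks the previously ignored return code of memcpy_s, and optimizer_info_builder.cc replaces a memcpy_s call that passed 0x00 as its source pointer with a proper, checked memset_s. A minimal standalone sketch of that checked-copy discipline; CheckedMemcpy is an illustrative stand-in for the securec memcpy_s, not a MindSpore API:

```cpp
#include <cstring>
#include <stdexcept>

constexpr int kEok = 0;
constexpr int kRangeError = 162;  // arbitrary non-zero code for this sketch

// Illustrative stand-in for securec's memcpy_s: copy only when the
// destination buffer is large enough, and report failure via a code.
int CheckedMemcpy(void *dest, size_t dest_max, const void *src, size_t count) {
  if (dest == nullptr || src == nullptr || count > dest_max) {
    return kRangeError;
  }
  std::memcpy(dest, src, count);
  return kEok;
}

void CopyKey(void *out_addr, size_t out_size, size_t key) {
  // Mirror of the PushKernel fix: never ignore the return code.
  int ret = CheckedMemcpy(out_addr, out_size, &key, sizeof(size_t));
  if (ret != kEok) {
    throw std::runtime_error("Lookup id memcpy failed.");
  }
}
```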
diff --git a/mindspore/ccsrc/frontend/parallel/ps/parameter_server.h b/mindspore/ccsrc/frontend/parallel/ps/parameter_server.h
index 64983458df..092e907da0 100644
--- a/mindspore/ccsrc/frontend/parallel/ps/parameter_server.h
+++ b/mindspore/ccsrc/frontend/parallel/ps/parameter_server.h
@@ -91,6 +91,8 @@ class ParameterServer {
                                    ::ps::KVPairs<T> *res);
     void HandleInitInputsShape(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
     void HandleInitEmbeddings(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
+    void HandleCheckReadyForPush(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
+    void HandleCheckReadyForPull(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
     void HandleEmbeddingLookup(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
     void HandleFinalize(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);

@@ -98,6 +100,9 @@ class ParameterServer {
     typedef void (ServerHandler::*RequestHandler)(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data,
                                                   ::ps::KVPairs<T> *res);
     std::unordered_map<int, RequestHandler> handlers_;
+    std::unordered_map<Key, bool> init_weights_;
+    std::unordered_map<Key, bool> init_weight_to_optim_;
+    std::unordered_map<Key, bool> init_optim_info_;
   };

   bool Init(const FuncGraphPtr &func_graph);
@@ -115,9 +120,11 @@ class ParameterServer {
   void DoEmbeddingLookup(Key key, const LookupIds &lookup_ids, ::ps::KVPairs<T> *res);
   int SumOfShapes(const std::vector<int> &shapes) const;
   bool ReadyForUpdateWeights();
-  bool ReadyForAccumGrads();
+  bool ReadyForPush(const Key &key);
+  bool ReadyForPull(const Key &key);
   void ResetGradAccumCount();
   const CNodePtr GetCNode(const std::string &name) const;
+  std::mutex &mutex();

   size_t pserver_num_;
   size_t worker_num_;
@@ -136,13 +143,14 @@ class ParameterServer {
   std::unordered_map<Key, std::string> weight_key_to_optims_;
   std::unordered_map<Key, std::string> weight_key_to_optim_op_;
   std::unordered_map<Key, WeightPtr> weights_;
+  std::unordered_map<Key, bool> is_embedding_;
   std::unordered_map<Key, WeightPtr> grads_;
   std::unordered_map<Key, size_t> grads_accum_counter_;
   std::unordered_map<Key, std::shared_ptr<PServerKernel>> embedding_lookup_ops_;
+  std::unordered_map<Key, int> tokens_;

   std::mutex mutex_;
   std::condition_variable apply_grads_cv_;
-  std::condition_variable accum_grads_cv_;

   std::unique_ptr<std::thread> thread_;
@@ -171,6 +179,8 @@ void ParameterServer<T>::ServerHandler::Init() {
   handlers_[kInitWeightToOptimIdCmd] = &ServerHandler::HandleInitWeightToOptimId;
   handlers_[kInitOptimInputsShapeCmd] = &ServerHandler::HandleInitInputsShape;
   handlers_[kInitEmbeddingsCmd] = &ServerHandler::HandleInitEmbeddings;
+  handlers_[kCheckReadyForPushCmd] = &ServerHandler::HandleCheckReadyForPush;
+  handlers_[kCheckReadyForPullCmd] = &ServerHandler::HandleCheckReadyForPull;
   handlers_[kEmbeddingLookupCmd] = &ServerHandler::HandleEmbeddingLookup;
   handlers_[kFinalizeCmd] = &ServerHandler::HandleFinalize;
 }
@@ -192,11 +202,17 @@ void ParameterServer<T>::ServerHandler::HandlePullReq(const ::ps::KVMeta &req_me
 template <typename T>
 void ParameterServer<T>::ServerHandler::HandleInitWeights(const ::ps::KVMeta &req_meta,
                                                           const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res) {
+  std::unique_lock<std::mutex> lock(ps_->mutex());
   size_t key_num = req_data.keys.size();
   T *data_ptr = req_data.vals.data();
   size_t pos = 0;
   for (size_t i = 0; i < key_num; i++) {
     Key key = req_data.keys[i];
+    if (init_weights_[key]) {
+      continue;
+    } else {
+      init_weights_[key] = true;
+    }
     size_t data_len = req_data.lens.size() != key_num ? req_data.vals.size() / key_num : req_data.lens[i];

     WeightPtr weight_ptr = std::make_shared<::ps::SArray<T>>();
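With several workers sending the same init commands, HandleInitWeights now serializes on the server-wide mutex and records each key in init_weights_, so a repeated kInitWeightsCmd for a key becomes a no-op. A self-contained sketch of that guard-map idiom; InitGuard and TryMarkInitialized are illustrative names, assuming ps-lite's unsigned 64-bit keys:

```cpp
#include <cstdint>
#include <mutex>
#include <unordered_map>

using Key = std::uint64_t;  // ps-lite keys are unsigned 64-bit ids

class InitGuard {
 public:
  // Returns true exactly once per key; later calls are suppressed.
  bool TryMarkInitialized(const Key &key) {
    std::unique_lock<std::mutex> lock(mutex_);
    if (initialized_[key]) {
      return false;  // another worker already initialized this key
    }
    initialized_[key] = true;
    return true;
  }

 private:
  std::mutex mutex_;
  std::unordered_map<Key, bool> initialized_;
};
```

The same guard is applied below to HandleInitWeightToOptimId, HandleInitInputsShape, and HandleInitEmbeddings, each with its own map.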
@@ -213,10 +229,16 @@ template <typename T>
 void ParameterServer<T>::ServerHandler::HandleInitWeightToOptimId(const ::ps::KVMeta &req_meta,
                                                                   const ::ps::KVPairs<T> &req_data,
                                                                   ::ps::KVPairs<T> *res) {
+  std::unique_lock<std::mutex> lock(ps_->mutex());
   size_t key_num = req_data.keys.size();
   for (size_t i = 0; i < key_num; i++) {
     Key key = req_data.keys[i];
     T val = req_data.vals[i];
+    if (init_weight_to_optim_[key]) {
+      continue;
+    } else {
+      init_weight_to_optim_[key] = true;
+    }
     ps_->InitWeightKeyToOptims(key, val);
   }
 }
@@ -224,12 +246,26 @@ template <typename T>
 void ParameterServer<T>::ServerHandler::HandleInitInputsShape(const ::ps::KVMeta &req_meta,
                                                               const ::ps::KVPairs<T> &req_data,
                                                               ::ps::KVPairs<T> *res) {
+  std::unique_lock<std::mutex> lock(ps_->mutex());
+  const Key &key = req_data.keys[0];
+  if (init_optim_info_[key]) {
+    return;
+  } else {
+    init_optim_info_[key] = true;
+  }
   ps_->InitOptimInputsShape(req_data.keys, req_data.vals, req_data.lens);
 }

 template <typename T>
 void ParameterServer<T>::ServerHandler::HandleInitEmbeddings(const ::ps::KVMeta &req_meta,
                                                              const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res) {
+  std::unique_lock<std::mutex> lock(ps_->mutex());
+  const Key &key = req_data.keys[0];
+  if (init_weights_[key]) {
+    return;
+  } else {
+    init_weights_[key] = true;
+  }
   std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> shapes =
     std::make_shared<std::vector<std::shared_ptr<std::vector<size_t>>>>();
   std::shared_ptr<std::vector<size_t>> input_shape = std::make_shared<std::vector<size_t>>();
@@ -239,7 +275,6 @@ void ParameterServer<T>::ServerHandler::HandleInitEmbeddings(const ::ps::KVMeta
   shapes->push_back(indices_shape);
   shapes->push_back(output_shape);

-  const Key &key = req_data.keys[0];
   const Lengths &lens = req_data.lens;
   size_t index = 0;
   for (int i = 0; i < lens[0]; i++) {
@@ -254,6 +289,26 @@ void ParameterServer<T>::ServerHandler::HandleInitEmbeddings(const ::ps::KVMeta
   ps_->InitEmbeddingTable(key, shapes);
 }

+template <typename T>
+void ParameterServer<T>::ServerHandler::HandleCheckReadyForPush(const ::ps::KVMeta &req_meta,
+                                                                const ::ps::KVPairs<T> &req_data,
+                                                                ::ps::KVPairs<T> *res) {
+  const Key &key = req_data.keys[0];
+  bool ready = ps_->ReadyForPush(key);
+  res->keys.push_back(key);
+  res->vals.push_back(ready);
+}
+
+template <typename T>
+void ParameterServer<T>::ServerHandler::HandleCheckReadyForPull(const ::ps::KVMeta &req_meta,
+                                                                const ::ps::KVPairs<T> &req_data,
+                                                                ::ps::KVPairs<T> *res) {
+  const Key &key = req_data.keys[0];
+  bool ready = ps_->ReadyForPull(key);
+  res->keys.push_back(key);
+  res->vals.push_back(ready);
+}
+
 template <typename T>
 void ParameterServer<T>::ServerHandler::HandleEmbeddingLookup(const ::ps::KVMeta &req_meta,
                                                               const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res) {
@@ -365,6 +420,8 @@ void ParameterServer<T>::InitWeight(const Key &key, const WeightPtr &weight) {
   MS_LOG(INFO) << "Initializing weight for key " << key;
   if (weights_.count(key) == 0) {
     weights_[key] = weight;
+    tokens_[key] = 0;
+    is_embedding_[key] = false;
   }
 }
@@ -399,6 +456,8 @@ void ParameterServer<T>::InitEmbeddingTable(
     embedding_data[i] = random(engine);
   }
   weights_[key] = embedding;
+  tokens_[key] = 0;
+  is_embedding_[key] = true;

   grads_accum_counter_[key] = 0;
 }
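InitWeight and InitEmbeddingTable seed tokens_[key] to 0 and record whether the key holds an embedding table. Together with the UpdateWeights and weight() changes below, tokens_ acts as a per-key counting gate: each optimizer step refills worker_num_ tokens, each weight pull consumes one, and pushes wait until the tokens are drained. A sketch of that lifecycle; TokenGate is an illustrative class, and the real ReadyForPush additionally requires grad_accum_count_ < weights_.size():

```cpp
#include <cstdint>
#include <unordered_map>

using Key = std::uint64_t;

// Per-key counting gate mirroring the tokens_ map in ParameterServer.
class TokenGate {
 public:
  explicit TokenGate(int worker_num) : worker_num_(worker_num) {}

  void OnInitWeight(const Key &key) { tokens_[key] = 0; }  // nothing to pull yet

  // After an optimizer step, grant one pull per worker.
  void OnWeightsUpdated(const Key &key) { tokens_[key] = worker_num_; }

  bool ReadyForPull(const Key &key) { return tokens_[key] > 0; }

  void OnPull(const Key &key) { tokens_[key] -= 1; }  // consume one token

  // A push is allowed only once every worker has pulled the fresh weights.
  bool ReadyForPush(const Key &key) { return tokens_[key] <= 0; }

 private:
  int worker_num_;
  std::unordered_map<Key, int> tokens_;
};
```

Embedding keys are excluded from the refill (is_embedding_ below), since embedding rows are served through the lookup path rather than pulled wholesale.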
@@ -439,17 +498,17 @@ void ParameterServer<T>::UpdateWeights() {
       optim_info->ComputeMean(worker_num_);
       optimizer->Execute(inputs, workspaces, outputs);
       optim_info->Reset();
+      if (!is_embedding_[key]) {
+        tokens_[key] = worker_num_;
+      }
     }
     ResetGradAccumCount();
-    accum_grads_cv_.notify_all();
   }
 }

 template <typename T>
 void ParameterServer<T>::AccumGrad(const Keys &keys, const Values &values, const Lengths &lengths) {
   std::unique_lock<std::mutex> lock(mutex_);
-  accum_grads_cv_.wait(lock, [this] { return this->ReadyForAccumGrads(); });
-
   const Key &key = keys[0];
   std::shared_ptr<OptimizerInfo> optim_info = optim_infos_[key];

@@ -482,14 +541,13 @@ template <typename T>
 WeightPtr ParameterServer<T>::weight(const Key &key) {
   std::unique_lock<std::mutex> lock(mutex_);
   if (weights_.count(key) == 0) {
-    MS_LOG(ERROR) << "Invalid weight key " << key;
-    return nullptr;
+    MS_LOG(EXCEPTION) << "Invalid weight key " << key;
   }
   WeightPtr weight_ptr = weights_[key];
   WeightPtr copy_weight_ptr = std::make_shared<::ps::SArray<T>>(weight_ptr->size(), 0);
   copy_weight_ptr->CopyFrom(weight_ptr->data(), weight_ptr->size());
+  tokens_[key] -= 1;
   return copy_weight_ptr;
 }
@@ -560,12 +618,22 @@ inline bool ParameterServer<T>::ReadyForUpdateWeights() {
 }

 template <typename T>
-inline bool ParameterServer<T>::ReadyForAccumGrads() {
+inline bool ParameterServer<T>::ReadyForPush(const Key &key) {
+  std::unique_lock<std::mutex> lock(mutex_);
   if (weights_.empty()) {
     MS_LOG(EXCEPTION) << "The weights in server is empty. Many reasons could cause this: 1.The Worker didn't send "
                          "kInitWeightsCmd command. 2.The Server failed to initialize weights.";
   }
-  return grad_accum_count_ < weights_.size();
+  return grad_accum_count_ < weights_.size() && tokens_[key] <= 0;
+}
+
+template <typename T>
+inline bool ParameterServer<T>::ReadyForPull(const Key &key) {
+  std::unique_lock<std::mutex> lock(mutex_);
+  if (tokens_.count(key) == 0 || weights_[key] == 0) {
+    MS_LOG(EXCEPTION) << "Invalid weight key " << key;
+  }
+  return tokens_[key] > 0;
 }

 template <typename T>
@@ -576,6 +644,11 @@ inline void ParameterServer<T>::ResetGradAccumCount() {
   }
 }

+template <typename T>
+inline std::mutex &ParameterServer<T>::mutex() {
+  return mutex_;
+}
+
 template <typename T>
 void ParameterServer<T>::Run(const FuncGraphPtr &func_graph) {
   ::ps::Start(0);
diff --git a/mindspore/ccsrc/frontend/parallel/ps/worker.h b/mindspore/ccsrc/frontend/parallel/ps/worker.h
index bde21904bb..ec220d5eff 100644
--- a/mindspore/ccsrc/frontend/parallel/ps/worker.h
+++ b/mindspore/ccsrc/frontend/parallel/ps/worker.h
@@ -99,18 +99,30 @@ void Worker<T>::Push(const std::vector<size_t> &keys, std::vector<uintptr_t> add
   ::ps::SArray<T> total_buffer(total_size, 0);
   size_t offset = 0;
   for (size_t i = 0; i < sizes.size(); i++) {
-    memcpy_s(total_buffer.data() + offset / sizeof(T), sizes[i] * sizeof(T), reinterpret_cast<void *>(addrs[i]),
-             sizes[i] * sizeof(T));
+    auto ret = memcpy_s(total_buffer.data() + offset / sizeof(T), sizes[i] * sizeof(T),
+                        reinterpret_cast<void *>(addrs[i]), sizes[i] * sizeof(T));
+    if (ret != 0) {
+      MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
+    }
     offset += sizes[i] * sizeof(T);
   }
+  while (!kv_worker_->IsReadyForPush(keys[0])) {
+    continue;
+  }
   kv_worker_->PushData(::ps::SArray<::ps::Key>(keys), total_buffer, ::ps::SArray<int>(sizes));
 }

 template <typename T>
 void Worker<T>::Pull(const size_t key, void *dev_addr, const size_t size) {
   ::ps::SArray<T> variables(size / sizeof(T), 0);
+  while (!kv_worker_->IsReadyForPull(key)) {
+    continue;
+  }
   kv_worker_->Wait(kv_worker_->ZPull({key}, &variables));
-  memcpy_s(dev_addr, size, variables.data(), size);
+  auto ret = memcpy_s(dev_addr, size, variables.data(), size);
+  if (ret != 0) {
+    MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
+  }
 }

 template <typename T>
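Worker<T>::Push and Pull now gate on the server's answer, but the loop body is a bare continue, so the worker spins at full speed between round trips. A hedged variation of the same gate that yields the thread between polls, as a sketch rather than what the patch ships:

```cpp
#include <thread>

// Illustrative poll loop for the worker-side gate. `is_ready` stands in for
// kv_worker_->IsReadyForPush / IsReadyForPull, each call being a full
// round trip to the server.
template <typename ReadyFn>
void WaitUntilReady(ReadyFn is_ready) {
  while (!is_ready()) {
    // The patch busy-spins (`continue;`); yielding keeps the polling thread
    // from monopolizing a core between round trips.
    std::this_thread::yield();
  }
}
```

With such a helper, the loop could read `WaitUntilReady([&] { return kv_worker_->IsReadyForPush(keys[0]); });`.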
diff --git a/mindspore/ccsrc/frontend/parallel/ps/worker_proxy.h b/mindspore/ccsrc/frontend/parallel/ps/worker_proxy.h
index 6ac7f6322d..7f73081ab7 100644
--- a/mindspore/ccsrc/frontend/parallel/ps/worker_proxy.h
+++ b/mindspore/ccsrc/frontend/parallel/ps/worker_proxy.h
@@ -56,6 +56,8 @@ class WorkerProxy : public ::ps::KVWorker<T> {
                        int priority = 0);
   int InitEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &vals,
                          const ::ps::SArray<int> &lens = {}, const Callback &cb = nullptr, int priority = 0);
+  bool IsReadyForPush(const Key &key);
+  bool IsReadyForPull(const Key &key);
   void PushData(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &vals, const ::ps::SArray<int> &lens = {},
                 int cmd = 0, int priority = 0);
   void Finalize();
@@ -134,6 +136,28 @@ int WorkerProxy<T>::InitEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, cons
   return ts;
 }

+template <typename T>
+bool WorkerProxy<T>::IsReadyForPush(const Key &key) {
+  ::ps::SArray<T> result(1, 0);
+  this->Wait(this->ZPull({key}, &result, nullptr, kCheckReadyForPushCmd));
+  if (result[0] > 0) {
+    return true;
+  } else {
+    return false;
+  }
+}
+
+template <typename T>
+bool WorkerProxy<T>::IsReadyForPull(const Key &key) {
+  ::ps::SArray<T> result(1, 0);
+  this->Wait(this->ZPull({key}, &result, nullptr, kCheckReadyForPullCmd));
+  if (result[0] > 0) {
+    return true;
+  } else {
+    return false;
+  }
+}
+
 template <typename T>
 void WorkerProxy<T>::PushData(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &vals,
                               const ::ps::SArray<int> &lens, int cmd, int priority) {
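Both predicates reuse the existing pull path rather than adding a new RPC: a ZPull tagged with kCheckReadyForPushCmd or kCheckReadyForPullCmd reaches HandleCheckReadyForPush/Pull on the server, which answers with a single value that is nonzero when ready. A compact sketch of both halves of that exchange, with the ps-lite KVPairs reduced to plain vectors for illustration:

```cpp
#include <cstdint>
#include <vector>

using Key = std::uint64_t;

// Stand-in for ::ps::KVPairs<float> in this sketch.
struct KVPairs {
  std::vector<Key> keys;
  std::vector<float> vals;
};

// Server half: answer a readiness query with one key/value pair,
// mirroring HandleCheckReadyForPush / HandleCheckReadyForPull.
KVPairs AnswerReadyQuery(const Key &key, bool ready) {
  KVPairs res;
  res.keys.push_back(key);
  res.vals.push_back(ready ? 1.0f : 0.0f);  // the bool travels as the value type T
  return res;
}

// Worker half: interpret the single returned value, mirroring
// WorkerProxy::IsReadyForPush / IsReadyForPull.
bool ParseReadyReply(const KVPairs &res) { return !res.vals.empty() && res.vals[0] > 0; }
```

Piggybacking on ZPull keeps the readiness check on the same request/response machinery and timestamps as ordinary pulls, at the cost of encoding a boolean in the value type T.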