Add ps module in batches

ZPaC 2020-07-09 18:55:20 +08:00
parent 4645a43e08
commit 0a1d3f1542
1 changed file with 559 additions and 0 deletions


@@ -0,0 +1,559 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_PARAMETER_SERVER_H_
#define MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_PARAMETER_SERVER_H_
#include <unistd.h>
#include <cstdlib>
#include <unordered_map>
#include <string>
#include <iostream>
#include <memory>
#include <vector>
#include <mutex>
#include <condition_variable>
#include <thread>
#include <cmath>
#include <random>
#include "ir/func_graph.h"
#include "session/session_basic.h"
#include "session/kernel_graph.h"
#include "session/anf_runtime_algorithm.h"
#include "session/session_factory.h"
#include "parallel/ps/common.h"
#include "parallel/ps/optimizer_info.h"
#include "parallel/ps/optimizer_info_builder.h"
#include "parallel/ps/util.h"
#include "device/cpu/kernel_select_cpu.h"
#include "utils/context/ms_context.h"
#include "kernel/kernel.h"
#include "kernel/ps/pserver_kernel.h"
#include "kernel/cpu/cpu_kernel_factory.h"
#include "kernel/ps/sparse_apply_adam_ps_kernel.h"
#include "kernel/ps/sparse_apply_ftrl_ps_kernel.h"
#include "kernel/ps/apply_momentum_ps_kernel.h"
#include "kernel/ps/embedding_look_up_ps_kernel.h"
namespace mindspore {
namespace parallel {
namespace ps {
using mindspore::kernel::ps::PServerKernel;
template <typename T>
class ParameterServer {
public:
static ParameterServer &GetInstance() {
static ParameterServer instance;
return instance;
}
void Run(const FuncGraphPtr &func_graph);
private:
ParameterServer()
: pserver_num_(0),
worker_num_(0),
rank_id_(0),
grad_accum_count_(0),
ps_(new ::ps::KVServer<T>(0)),
handler_(nullptr),
func_graph_(nullptr),
kernel_graph_(nullptr),
sess_(nullptr),
thread_(nullptr) {}
~ParameterServer() = default;
ParameterServer(const ParameterServer &) = delete;
ParameterServer &operator=(const ParameterServer &) = delete;
struct ServerHandler {
explicit ServerHandler(ParameterServer *ps) : ps_(ps) {}
void operator()(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVServer<T> *server);
void HandlePushReq(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data);
void HandlePullReq(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
void HandleInitWeights(const ::ps::KVPairs<T> &req_data);
void HandleInitWeightToOptimId(const ::ps::KVPairs<T> &req_data);
void HandleInitInputsShape(const ::ps::KVPairs<T> &req_data);
void HandleInitEmbeddings(const ::ps::KVPairs<T> &req_data);
void HandleEmbeddingLookup(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
ParameterServer *ps_;
};
bool Init(const FuncGraphPtr &func_graph);
void InitOptimInfoBuilders();
void InitWeightKeyToOptims(const Key &key, const int &optim_id);
void InitOptimInputsShape(const Keys &keys, const Values &values, const Lengths &lengths);
void InitWeight(const Key &key, const WeightPtr &weight);
void InitGrad(const Key &key, const GradPtr &grad);
void InitEmbeddingTable(const Key &key,
const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes);
void UpdateWeights();
void AccumGrad(const Keys &keys, const Values &values, const Lengths &lengths);
WeightPtr weight(const Key &key);
void DoEmbeddingLookup(Key key, const LookupIds &lookup_ids, ::ps::KVPairs<T> *res);
int SumOfShapes(const std::vector<int> &shapes) const;
size_t PreComputeCapacity(const Keys &keys, const Lengths &lens);
bool ReadyForUpdateWeights();
bool ReadyForAccumGrads();
void ResetGradAccumCount();
size_t pserver_num_;
size_t worker_num_;
size_t rank_id_;
size_t grad_accum_count_;
std::unique_ptr<::ps::KVServer<T>> ps_;
std::unique_ptr<ServerHandler> handler_;
FuncGraphPtr func_graph_;
std::shared_ptr<session::KernelGraph> kernel_graph_;
std::shared_ptr<session::SessionBasic> sess_;
std::unordered_map<std::string, std::shared_ptr<PServerKernel>> optimizers_;
std::unordered_map<Key, InputsShapePtr> optim_inputs_shape_;
std::unordered_map<Key, std::shared_ptr<OptimizerInfo>> optim_infos_;
std::unordered_map<std::string, std::shared_ptr<OptimizerInfoBuilder>> optim_info_builders_;
std::unordered_map<Key, std::string> weight_key_to_optims_;
std::unordered_map<Key, WeightPtr> weights_;
std::unordered_map<Key, WeightPtr> grads_;
std::unordered_map<Key, size_t> grads_accum_counter_;
// std::unordered_map<Key, EmbeddingTablePtr> embeddings_;
std::unordered_map<Key, std::shared_ptr<PServerKernel>> embedding_lookup_ops_;
std::unordered_map<Key, size_t> embedding_row_lens_;
T learning_rate_;
T momentum_;
std::mutex mutex_;
std::condition_variable apply_grads_cv_;
std::condition_variable accum_grads_cv_;
std::unique_ptr<std::thread> thread_;
friend struct ServerHandler;
};
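// Note: ParameterServer is a process-wide singleton templated on the value type T exchanged
// with ps-lite (float in practice); all mutable state is guarded by mutex_ and every KV
// request funnels through ServerHandler::operator() below. An illustrative usage sketch from
// a server process (not part of this header):
//
//   mindspore::parallel::ps::ParameterServer<float>::GetInstance().Run(func_graph);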
class FuncGraph;
template <typename T>
void ParameterServer<T>::ServerHandler::operator()(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data,
::ps::KVServer<T> *server) {
::ps::KVPairs<T> res;
if (req_meta.cmd == kInitWeightsCmd) {
MS_LOG(INFO) << "handle init weights cmd";
HandleInitWeights(req_data);
} else if (req_meta.cmd == kInitWeightToOptimIdCmd) {
MS_LOG(INFO) << "handle init weight to optim id mapping cmd";
HandleInitWeightToOptimId(req_data);
} else if (req_meta.cmd == kInitOptimInputsShapeCmd) {
MS_LOG(INFO) << "handle init optimizer inputs shape cmd";
HandleInitInputsShape(req_data);
} else if (req_meta.cmd == kInitEmbeddingsCmd) {
MS_LOG(INFO) << "handle init embeddings cmd";
HandleInitEmbeddings(req_data);
} else if (req_meta.cmd == kEmbeddingLookupCmd) {
MS_LOG(INFO) << "handle embedding lookup cmd";
HandleEmbeddingLookup(req_meta, req_data, &res);
} else if (req_meta.push) {
MS_LOG(INFO) << "handle push req";
HandlePushReq(req_meta, req_data);
} else {
MS_LOG(INFO) << "handle pull req";
HandlePullReq(req_meta, req_data, &res);
}
server->Response(req_meta, res);
}
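// Note: ps-lite delivers every push/pull to this functor along with req_meta.cmd, so the
// custom kInit*/kEmbeddingLookup commands multiplex control traffic over the same KV channel
// as gradients. A worker-side sketch, assuming the upstream ps-lite KVWorker API
// (identifiers such as weight_size are illustrative):
//
//   ::ps::KVWorker<float> kv(0, 0);
//   ::ps::SArray<::ps::Key> keys;
//   keys.push_back(3);
//   ::ps::SArray<float> vals(weight_size, 0);      // packed initial weight values
//   ::ps::SArray<int> lens;
//   lens.push_back(static_cast<int>(vals.size()));
//   kv.ZPush(keys, vals, lens, kInitWeightsCmd);   // routed to HandleInitWeights above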
template <typename T>
void ParameterServer<T>::ServerHandler::HandlePushReq(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data) {
ps_->AccumGrad(req_data.keys, req_data.vals, req_data.lens);
}
template <typename T>
void ParameterServer<T>::ServerHandler::HandlePullReq(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data,
::ps::KVPairs<T> *res) {
res->keys = req_data.keys;
::ps::Key key = req_data.keys[0];
res->vals = *(ps_->weight(key));
}
template <typename T>
void ParameterServer<T>::ServerHandler::HandleInitWeights(const ::ps::KVPairs<T> &req_data) {
size_t key_num = req_data.keys.size();
T *data_ptr = req_data.vals.data();
size_t pos = 0;
for (size_t i = 0; i < key_num; i++) {
Key key = req_data.keys[i];
size_t data_len = req_data.lens.size() != key_num ? req_data.vals.size() / key_num : req_data.lens[i];
WeightPtr weight_ptr = std::make_shared<::ps::SArray<T>>();
weight_ptr->CopyFrom(data_ptr + pos, data_len);
ps_->InitWeight(key, weight_ptr);
GradPtr grad_ptr = std::make_shared<::ps::SArray<T>>(data_len, 0);
ps_->InitGrad(key, grad_ptr);
pos += data_len;
}
}
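// Packing convention (worked example): with keys = {3, 7}, lens = {4, 2} and
// vals = {w0 w1 w2 w3 | u0 u1}, the loop above copies vals[0..3] into the weight for key 3
// and vals[4..5] into the weight for key 7, allocating a same-sized, zero-filled gradient
// buffer for each. If lens does not match key_num, vals is split evenly across the keys.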
template <typename T>
void ParameterServer<T>::ServerHandler::HandleInitWeightToOptimId(const ::ps::KVPairs<T> &req_data) {
size_t key_num = req_data.keys.size();
for (size_t i = 0; i < key_num; i++) {
Key key = req_data.keys[i];
T val = req_data.vals[i];
ps_->InitWeightKeyToOptims(key, val);
}
}
template <typename T>
void ParameterServer<T>::ServerHandler::HandleInitInputsShape(const ::ps::KVPairs<T> &req_data) {
ps_->InitOptimInputsShape(req_data.keys, req_data.vals, req_data.lens);
}
template <typename T>
void ParameterServer<T>::ServerHandler::HandleInitEmbeddings(const ::ps::KVPairs<T> &req_data) {
std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> shapes =
std::make_shared<std::vector<std::shared_ptr<std::vector<size_t>>>>();
std::shared_ptr<std::vector<size_t>> input_shape = std::make_shared<std::vector<size_t>>();
std::shared_ptr<std::vector<size_t>> indices_shape = std::make_shared<std::vector<size_t>>();
std::shared_ptr<std::vector<size_t>> output_shape = std::make_shared<std::vector<size_t>>();
shapes->push_back(input_shape);
shapes->push_back(indices_shape);
shapes->push_back(output_shape);
const Key &key = req_data.keys[0];
const Lengths &lens = req_data.lens;
size_t index = 0;
for (int i = 0; i < lens[0]; i++) {
input_shape->push_back(static_cast<size_t>(req_data.vals[index++]));
}
for (int j = 0; j < lens[1]; j++) {
indices_shape->push_back(static_cast<size_t>(req_data.vals[index++]));
}
for (int k = 0; k < lens[2]; k++) {
output_shape->push_back(static_cast<size_t>(req_data.vals[index++]));
}
ps_->InitEmbeddingTable(key, shapes);
}
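// Payload layout (worked example): a single key ships three shapes back to back. With
// lens = {2, 1, 2} and vals = {16000, 80, 32, 32, 80}, the table shape becomes [16000, 80],
// the indices shape [32], and the output shape [32, 80] before InitEmbeddingTable builds the
// lookup kernel for req_data.keys[0].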
template <typename T>
void ParameterServer<T>::ServerHandler::HandleEmbeddingLookup(const ::ps::KVMeta &req_meta,
const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res) {
const Key &key = req_data.keys[0];
ps_->DoEmbeddingLookup(key, req_data.vals, res);
for (size_t i = 0; i < req_data.vals.size(); i++) {
res->keys.push_back(req_data.vals[i]);
}
}
template <typename T>
bool ParameterServer<T>::Init(const FuncGraphPtr &func_graph) {
const char *server_num = getenv(kEnvPServerNum);
const char *worker_num = getenv(kEnvWorkerNum);
if (server_num != nullptr) {
// Parse the full decimal value so configurations with 10+ servers/workers work.
pserver_num_ = std::strtoul(server_num, nullptr, 10);
}
if (worker_num != nullptr) {
worker_num_ = std::strtoul(worker_num, nullptr, 10);
}
func_graph_ = func_graph;
rank_id_ = ::ps::MyRank();
handler_.reset(new ServerHandler(this));
InitOptimInfoBuilders();
ps_->set_request_handle(*handler_);
thread_.reset(new std::thread(&ParameterServer::UpdateWeights, this));
return true;
}
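// Note: pserver_num_ / worker_num_ come from the environment variables named by
// kEnvPServerNum / kEnvWorkerNum (defined in parallel/ps/common.h). An illustrative launcher
// snippet, assuming those constants are plain C strings:
//
//   setenv(kEnvPServerNum, "2", 1);  // 2 parameter servers
//   setenv(kEnvWorkerNum, "8", 1);   // 8 workers
//
// Init() also spawns the long-running UpdateWeights() thread that Run() joins later.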
template <typename T>
void ParameterServer<T>::InitOptimInfoBuilders() {
std::shared_ptr<OptimizerInfoBuilder> momentum_info_builder = std::make_shared<MomentumOptimInfoBuilder>();
std::shared_ptr<OptimizerInfoBuilder> sparse_adam_info_builder = std::make_shared<SparseAdamOptimInfoBuilder>();
std::shared_ptr<OptimizerInfoBuilder> sparse_ftrl_info_builder = std::make_shared<SparseFtrlOptimInfoBuilder>();
optim_info_builders_[kApplyMomentum] = momentum_info_builder;
optim_info_builders_[kSparseAdam] = sparse_adam_info_builder;
optim_info_builders_[kSparseFtrl] = sparse_ftrl_info_builder;
}
template <typename T>
void ParameterServer<T>::InitWeightKeyToOptims(const Key &key, const int &optim_id) {
if (weight_key_to_optims_.count(key) > 0 || Util::optimizer_name(optim_id) == "") {
return;
}
weight_key_to_optims_[key] = Util::optimizer_name(optim_id);
}
template <typename T>
void ParameterServer<T>::InitOptimInputsShape(const Keys &keys, const Values &values, const Lengths &lengths) {
InputsShapePtr inputs_shape = std::make_shared<InputsShape>();
int val_idx = 0;
const Key &key = keys[0];
if (optim_inputs_shape_.count(key) == 0) {
optim_inputs_shape_[key] = inputs_shape;
}
for (size_t i = 0; i < keys.size(); i++) {
auto shape = std::make_shared<std::vector<size_t>>();
inputs_shape->push_back(shape);
int len = lengths[i];
for (int j = 0; j < len; j++) {
shape->push_back(values[val_idx++]);
}
}
if (weight_key_to_optims_.count(key) > 0) {
const std::string &optim_name = weight_key_to_optims_[key];
if (optimizers_.count(optim_name) == 0 && optim_inputs_shape_.count(key) > 0) {
if (optim_name == kSparseAdam) {
std::shared_ptr<PServerKernel> optimizer =
std::make_shared<kernel::ps::SparseApplyAdamPSKernel>(rank_id_, pserver_num_);
optimizer->InitKernel(optim_inputs_shape_[key]);
optimizers_[optim_name] = optimizer;
} else if (optim_name == kApplyMomentum) {
std::shared_ptr<PServerKernel> optimizer =
std::make_shared<kernel::ps::ApplyMomentumPSKernel>(rank_id_, pserver_num_);
optimizer->InitKernel(optim_inputs_shape_[key]);
optimizers_[optim_name] = optimizer;
} else if (optim_name == kSparseFtrl) {
std::shared_ptr<PServerKernel> optimizer =
std::make_shared<kernel::ps::SparseApplyFtrlPSKernel>(rank_id_, pserver_num_);
optimizer->InitKernel(optim_inputs_shape_[key]);
optimizers_[optim_name] = optimizer;
}
}
}
}
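// Shape packing (illustrative example): each entry of `keys` corresponds to one optimizer
// input, `lengths[i]` gives the rank of input i, and `values` concatenates the dimensions.
// E.g. keys = {5, 5, 5}, lengths = {2, 2, 1}, values = {1024, 10, 1024, 10, 10} describes a
// weight [1024, 10], a gradient [1024, 10] and a [10] input; the matching optimizer kernel
// (SparseAdam / ApplyMomentum / SparseFtrl) is then built once per optimizer name.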
template <typename T>
void ParameterServer<T>::InitWeight(const Key &key, const WeightPtr &weight) {
if (weights_.count(key) == 0) {
weights_[key] = weight;
}
}
template <typename T>
void ParameterServer<T>::InitGrad(const Key &key, const GradPtr &grad) {
if (grads_.count(key) == 0) {
grads_[key] = grad;
grads_accum_counter_[key] = 0;
}
}
template <typename T>
void ParameterServer<T>::InitEmbeddingTable(
const Key &key, const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes) {
// Init embedding lookup kernel
std::shared_ptr<PServerKernel> lookup = std::make_shared<kernel::ps::EmbeddingLookUpPSKernel>(rank_id_, pserver_num_);
lookup->InitKernel(shapes);
embedding_lookup_ops_[key] = lookup;
// Init embedding weight
const std::vector<size_t> &input_shapes = lookup->input_sizes();
size_t total_dims = 1;
for (auto shape : input_shapes) {
total_dims *= shape;
}
WeightPtr embedding = std::make_shared<Weight>(total_dims, 0.01);
weights_[key] = embedding;
grads_accum_counter_[key] = 0;
}
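// Sizing (worked example): for input_sizes() = {16000, 80} the table holds
// 16000 * 80 = 1,280,000 values, allocated as one flat SArray and constant-initialized to
// 0.01 above; individual rows are addressed by the lookup kernel, not by this class.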
template <typename T>
void ParameterServer<T>::UpdateWeights() {
while (true) {
std::unique_lock<std::mutex> lock(mutex_);
apply_grads_cv_.wait(lock, [this] { return this->ReadyForUpdateWeights(); });
for (auto iter = weights_.begin(); iter != weights_.end(); iter++) {
Key key = iter->first;
WeightPtr weight_ptr = iter->second;
std::shared_ptr<PServerKernel> optimizer = nullptr;
if (weight_key_to_optims_.count(key) > 0) {
const std::string &optim_name = weight_key_to_optims_[key];
optimizer = optimizers_[optim_name];
}
MS_EXCEPTION_IF_NULL(optimizer);
std::shared_ptr<OptimizerInfo> optim_info = optim_infos_[key];
if (optim_info == nullptr) {
continue;
}
const WeightPtr &weight = weights_[key];
optim_info->UpdateWeight(weight);
const std::vector<kernel::AddressPtr> &inputs = optim_info->inputs();
const std::vector<kernel::AddressPtr> &workspaces = optim_info->workspaces();
const std::vector<kernel::AddressPtr> &outputs = optim_info->outputs();
optimizer->Execute(inputs, workspaces, outputs);
optim_info->Reset();
}
ResetGradAccumCount();
accum_grads_cv_.notify_all();
}
}
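// Synchronization note: UpdateWeights() is the consumer side of a condition-variable
// handshake. It sleeps until every weight has a fully accumulated gradient
// (ReadyForUpdateWeights), applies each optimizer kernel, zeroes the per-key counters, then
// wakes the producer side (AccumGrad) blocked on accum_grads_cv_. Because wait() re-checks
// its predicate, spurious wakeups are harmless.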
template <typename T>
void ParameterServer<T>::AccumGrad(const Keys &keys, const Values &values, const Lengths &lengths) {
std::unique_lock<std::mutex> lock(mutex_);
accum_grads_cv_.wait(lock, [this] { return this->ReadyForAccumGrads(); });
const Key &key = keys[0];
std::shared_ptr<OptimizerInfo> optim_info = optim_infos_[key];
// Create or update the optimizer info
if (optim_info == nullptr) {
const std::shared_ptr<OptimizerInfoBuilder> &builder = optim_info_builders_[weight_key_to_optims_[key]];
std::shared_ptr<kernel::ps::PServerKernel> pserver_kernel = optimizers_[weight_key_to_optims_[key]];
if (pserver_kernel == nullptr) {
MS_LOG(EXCEPTION) << "No optimizer found for key " << key << ", optimizer name: " << weight_key_to_optims_[key];
}
OptimizerInfo *optim =
builder->Build(pserver_kernel, weights_[key], keys, values, lengths, optim_inputs_shape_[key], worker_num_);
optim_info.reset(optim);
optim_infos_[key] = optim_info;
} else {
optim_info->Update(values, lengths);
}
MS_EXCEPTION_IF_NULL(optim_info);
optim_info->Accumulate(values, lengths);
grads_accum_counter_[key] += 1;
if (grads_accum_counter_[key] == worker_num_) {
grad_accum_count_++;
}
if (ReadyForUpdateWeights()) {
apply_grads_cv_.notify_one();
}
}
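// Counting (worked example): with worker_num_ = 2 and two tracked weights {k1, k2}, the
// counters evolve as pushes arrive: k1:1 -> k1:2 (grad_accum_count_ = 1) -> k2:1 -> k2:2
// (grad_accum_count_ = 2 == grads_accum_counter_.size()), at which point apply_grads_cv_ is
// notified and UpdateWeights() runs one synchronous step.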
template <typename T>
WeightPtr ParameterServer<T>::weight(const Key &key) {
std::unique_lock<std::mutex> lock(mutex_);
if (weights_.count(key) == 0) {
MS_LOG(ERROR) << "Invalid weight key " << key;
return nullptr;
}
WeightPtr weight_ptr = weights_[key];
WeightPtr copy_weight_ptr = std::make_shared<::ps::SArray<T>>(weight_ptr->size(), 0);
copy_weight_ptr->CopyFrom(weight_ptr->data(), weight_ptr->size());
return copy_weight_ptr;
}
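// Note: the getter hands back a deep copy (CopyFrom into a fresh SArray) rather than the
// live buffer, so a pull response can be serialized while the update thread keeps mutating
// weights_ under mutex_.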
template <typename T>
void ParameterServer<T>::DoEmbeddingLookup(Key key, const LookupIds &lookup_ids, ::ps::KVPairs<T> *res) {
std::unique_lock<std::mutex> lock(mutex_);
if (weights_.count(key) == 0) {
MS_LOG(ERROR) << "Invalid embedding table key " << key;
return;
}
if (embedding_lookup_ops_.count(key) == 0) {
MS_LOG(ERROR) << "Invalid embedding lookup op key " << key;
return;
}
WeightPtr table_ptr = weights_[key];
std::shared_ptr<PServerKernel> table_lookup_op = embedding_lookup_ops_[key];
// Update shapes of lookup operator
std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> shapes =
std::make_shared<std::vector<std::shared_ptr<std::vector<size_t>>>>();
std::shared_ptr<std::vector<size_t>> indices_shape = std::make_shared<std::vector<size_t>>();
indices_shape->emplace_back(lookup_ids.size());
shapes->push_back(indices_shape);
table_lookup_op->ReInit(shapes);
const std::vector<size_t> output_shapes = table_lookup_op->output_sizes();
std::vector<kernel::AddressPtr> inputs;
AddressPtr embedding_table = std::make_shared<kernel::Address>();
AddressPtr indices = std::make_shared<kernel::Address>();
inputs.push_back(embedding_table);
inputs.push_back(indices);
embedding_table->addr = table_ptr->data();
embedding_table->size = table_ptr->size() * sizeof(T);
indices->addr = lookup_ids.data();
indices->size = lookup_ids.size() * sizeof(T);
std::vector<kernel::AddressPtr> workspaces;
std::vector<kernel::AddressPtr> outputs;
AddressPtr output = std::make_shared<kernel::Address>();
std::shared_ptr<Values> addr = std::make_shared<Values>(output_shapes[0] / sizeof(T), 0);
output->addr = addr->data();
output->size = output_shapes[0];
outputs.push_back(output);
table_lookup_op->Execute(inputs, workspaces, outputs);
res->vals = *addr;
res->lens.push_back(res->vals.size());
}
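// Buffer wiring (worked example): output_sizes() reports bytes, so the result is sized
// output_shapes[0] / sizeof(T) elements; a [32, 80] float output is 10240 bytes, i.e. 2560
// values. res->vals shares the looked-up buffer (SArray copies are shallow) and res->lens
// carries the element count back to the worker.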
template <typename T>
int ParameterServer<T>::SumOfShapes(const std::vector<int> &shapes) const {
int sum = 1;
for (auto shape : shapes) {
sum *= shape;
}
return sum;
}
template <typename T>
size_t ParameterServer<T>::PreComputeCapacity(const Keys &keys, const Lengths &lens) {
size_t capacity = 0;
for (size_t i = 0; i < keys.size(); i++) {
Key key = keys[i];
if (embedding_row_lens_.count(key) > 0) {
capacity += embedding_row_lens_[key] * lens[i];
} else {
MS_LOG(ERROR) << "Invalid embedding lookup id " << key;
}
}
return capacity;
}
template <typename T>
inline bool ParameterServer<T>::ReadyForUpdateWeights() {
return grads_accum_counter_.size() > 0 && grad_accum_count_ == grads_accum_counter_.size();
}
template <typename T>
inline bool ParameterServer<T>::ReadyForAccumGrads() {
return grad_accum_count_ < weights_.size();
}
template <typename T>
inline void ParameterServer<T>::ResetGradAccumCount() {
grad_accum_count_ = 0;
for (auto iter = grads_accum_counter_.begin(); iter != grads_accum_counter_.end(); iter++) {
grads_accum_counter_[iter->first] = 0;
}
}
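// Predicate summary: ReadyForUpdateWeights() fires once every tracked weight has a complete
// gradient; ReadyForAccumGrads() throttles workers while a full batch of gradients is
// pending, and only readmits them after ResetGradAccumCount() zeroes grad_accum_count_.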
template <typename T>
void ParameterServer<T>::Run(const FuncGraphPtr &func_graph) {
::ps::Start(0);
if (!::ps::IsServer()) {
MS_LOG(INFO) << "This process is not a parameter server, nothing to run";
return;
}
Init(func_graph);
thread_->join();
}
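// Lifecycle note: UpdateWeights() loops forever, so thread_->join() never returns in normal
// operation; the server effectively runs until the process is torn down.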
} // namespace ps
} // namespace parallel
} // namespace mindspore
#endif // MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_PARAMETER_SERVER_H_