replace ps-lite

This commit is contained in:
chendongsheng 2021-02-27 17:10:00 +08:00
parent e99c29c7d9
commit db0a6f1e19
50 changed files with 415 additions and 2995 deletions

View File

@@ -1,22 +0,0 @@
# external_libs/pslite.cmake — download, patch and build the ps-lite
# parameter-server library, then expose it as the alias target mindspore::pslite.
if(ENABLE_GITEE)
    # Gitee mirror for environments where github.com is unreachable.
    set(REQ_URL "https://gitee.com/mirrors/ps-lite/repository/archive/34fd45cae457d59850fdcb2066467778d0673f21.zip")
    set(MD5 "0d1543b8dcb0bc3610637e1643c94eb4")
else()
    set(REQ_URL "https://github.com/dmlc/ps-lite/archive/34fd45cae457d59850fdcb2066467778d0673f21.zip")
    set(MD5 "393c0e27b68bfaf96718caa3aa96f5a3")
endif()

# Build ps-lite as a static library.
set(pslite_USE_STATIC_LIBS ON)

# Quote the expansion: the unquoted form `if(${ENABLE_IBVERBS} STREQUAL "ON")`
# collapses to `if( STREQUAL "ON")` — a configure-time error — whenever
# ENABLE_IBVERBS is unset or empty.
if("${ENABLE_IBVERBS}" STREQUAL "ON")
    set(pslite_CXXFLAGS "USE_IBVERBS=1")
endif()

# mindspore_add_pkg is the project-local fetch/patch/build helper.
mindspore_add_pkg(pslite
    LIBS ps
    URL ${REQ_URL}
    MD5 ${MD5}
    PATCHES ${CMAKE_SOURCE_DIR}/third_party/patch/pslite/ps_lite.patch001
    ONLY_MAKE True
    ONLY_MAKE_INCS include/*
    ONLY_MAKE_LIBS build/*)

# NOTE(review): directory-scoped include kept for compatibility — many targets
# consume ${pslite_INC}; converting to target_include_directories would require
# touching every consumer.
include_directories(${pslite_INC})
add_library(mindspore::pslite ALIAS pslite::ps)

View File

@@ -1,5 +0,0 @@
# external_libs/zeromq.cmake — fetch zeromq 4.1.4 via the project-local
# mindspore_add_pkg helper; HEAD_ONLY ./ means it is consumed header-only,
# no build step is run.
# NOTE(review): the tarball comes from a personal mirror (mli/deps) — confirm
# the mirror is still trusted and available.
mindspore_add_pkg(zeromq
VER 4.1.4
HEAD_ONLY ./
URL https://raw.githubusercontent.com/mli/deps/master/build/zeromq-4.1.4.tar.gz
MD5 a611ecc93fffeb6d058c0e6edf4ad4fb)

View File

@ -32,10 +32,6 @@ include(${CMAKE_SOURCE_DIR}/cmake/external_libs/flatbuffers.cmake)
if(USE_GLOG)
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/glog.cmake)
endif()
if(ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/zeromq.cmake)
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/pslite.cmake)
endif()
find_package(Python3)
include_directories(${Python3_INCLUDE_DIRS})

View File

@ -339,8 +339,8 @@ elseif(CMAKE_SYSTEM_NAME MATCHES "Darwin")
target_link_libraries(_c_expression PRIVATE -Wl,-force_load mindspore -Wl,-noall_load)
else()
if(ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))
target_link_libraries(mindspore mindspore::pslite proto_input mindspore::protobuf
mindspore::event mindspore::event_pthreads ${zeromq_DIRPATH}/zmq_install/lib/libzmq.a)
target_link_libraries(mindspore proto_input mindspore::protobuf
mindspore::event mindspore::event_pthreads)
target_link_libraries(mindspore -Wl,--no-as-needed mindspore::event_core ps_cache)
if(${ENABLE_IBVERBS} STREQUAL "ON")
target_link_libraries(mindspore ibverbs rdmacm)

View File

@ -17,6 +17,7 @@
#include <vector>
#include <algorithm>
#include "ps/worker.h"
#include "ps/util.h"
namespace mindspore {
namespace kernel {
@ -35,7 +36,7 @@ void EmbeddingLookUpProxyKernel::InitKernel(const CNodePtr &kernel_node) {
<< input_shape << " is too large.";
}
if (mindspore::ps::Util::IsRoleOfWorker()) {
if (mindspore::ps::PSContext::instance()->is_worker()) {
key_ = AnfAlgo::GetNodeAttr<size_t>(kernel_node, kAttrPsKey);
}
std::vector<size_t> keys{key_, key_, key_};
@ -50,9 +51,10 @@ void EmbeddingLookUpProxyKernel::InitKernel(const CNodePtr &kernel_node) {
<< ", indices_shape:" << indices_shape << ", output_shape:" << output_shape;
std::vector<int64_t> lens{SizeToLong(input_shape.size()), SizeToLong(indices_shape.size()),
SizeToLong(output_shape.size())};
if (mindspore::ps::Util::IsRoleOfWorker()) {
if (mindspore::ps::PSContext::instance()->is_worker()) {
mindspore::ps::worker.AddEmbeddingTable(key_, input_shape[axis]);
mindspore::ps::worker.InitPSEmbeddingTable(keys, values, lens);
mindspore::ps::ParamInitInfoMessage info;
mindspore::ps::worker.InitPSEmbeddingTable(key_, input_shape, indices_shape, output_shape, info);
}
}
@ -70,17 +72,16 @@ bool EmbeddingLookUpProxyKernel::Launch(const std::vector<kernel::AddressPtr> &i
size_t input_size = inputs[1]->size;
size_t output_size = outputs[0]->size;
size_t size = input_size / sizeof(float);
::ps::SArray<int> lookup_ids(size, 0);
::ps::SArray<int> lengths{size};
::ps::SArray<float> lookup_result(output_size / sizeof(float), 0);
size_t size = input_size / sizeof(int);
std::vector<int> lookup_ids(size, 0);
std::vector<int> lengths{SizeToInt(size)};
std::vector<float> lookup_result(output_size / sizeof(float), 0);
auto ret = memcpy_s(lookup_ids.data(), lookup_ids.size() * sizeof(int), indices_addr, input_size);
if (ret != EOK) {
MS_LOG(EXCEPTION) << "Lookup id memcpy failed.";
return false;
}
mindspore::ps::worker.DoPSEmbeddingLookup({key_}, lookup_ids, lengths, &lookup_result,
mindspore::ps::kEmbeddingLookupCmd);
mindspore::ps::worker.DoPSEmbeddingLookup(key_, lookup_ids, &lookup_result, mindspore::ps::kEmbeddingLookupCmd);
auto ret2 = memcpy_s(output_addr, outputs[0]->size, lookup_result.data(), output_size);
if (ret2 != EOK) {

View File

@ -62,7 +62,7 @@ class PullKernel : public CPUKernel {
MS_EXCEPTION_IF_NULL(param_node);
param_name_ = param_node->fullname_with_scope();
if (mindspore::ps::Util::IsRoleOfWorker()) {
if (mindspore::ps::PSContext::instance()->is_worker()) {
key_ = AnfAlgo::GetNodeAttr<size_t>(kernel_node, kAttrPsKey);
}
InitSizeLists();

View File

@ -30,6 +30,7 @@
#include "backend/optimizer/pass/replace_node_by_proxy.h"
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
#include "ps/util.h"
#include "ps/ps_context.h"
#endif
namespace mindspore {
@ -75,9 +76,9 @@ GraphId CPUSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtr
MS_LOG(INFO) << "Set kernel info";
SetKernelInfo(graph.get());
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
if (ps::Util::IsParamServerMode()) {
if (ps::PSContext::instance()->is_ps_mode()) {
AssignParamKey(graph);
if (ps::Util::IsRoleOfWorker()) {
if (ps::PSContext::instance()->is_worker()) {
Optimize(graph);
}
}

View File

@ -41,8 +41,9 @@
#include "utils/trace_base.h"
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
#include "ps/ps_cache/ps_cache_manager.h"
#include "ps/common.h"
#include "ps/constants.h"
#include "ps/util.h"
#include "ps/ps_context.h"
#include "abstract/abstract_value.h"
#endif
@ -2287,7 +2288,7 @@ void SessionBasic::RunOpHideNopNode(const KernelGraphPtr &kernel_graph) const {
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
void SessionBasic::InitPsWorker(const KernelGraphPtr &kernel_graph) {
if (!ps::Util::IsRoleOfWorker()) {
if (!ps::PSContext::instance()->is_worker()) {
return;
}
CheckPSModeConsistence(kernel_graph);
@ -2384,7 +2385,7 @@ void SessionBasic::AssignParamKey(const KernelGraphPtr &kernel_graph) {
void SessionBasic::InitPSParamAndOptim(const KernelGraphPtr &kernel_graph,
const std::vector<tensor::TensorPtr> &inputs_const) {
if (!ps::Util::IsRoleOfWorker()) {
if (!ps::PSContext::instance()->is_worker()) {
return;
}
std::vector<tensor::TensorPtr> inputs(inputs_const);

View File

@ -48,6 +48,7 @@
#include "mindspore/core/utils/parallel_node_check.h"
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
#include "ps/util.h"
#include "ps/ps_context.h"
#endif
using mindspore::tensor::Tensor;
@ -3283,7 +3284,7 @@ static void HandleNoUsedParameter(const FuncGraphPtr &root) {
bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) {
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
if (ps::Util::IsRoleOfPServer() || ps::Util::IsRoleOfScheduler()) {
if (ps::PSContext::instance()->is_server() || ps::PSContext::instance()->is_scheduler()) {
return false;
}
#endif

View File

@ -288,7 +288,6 @@ if(${CMAKE_SYSTEM_NAME} MATCHES "Windows")
else()
target_link_libraries(_c_dataengine PRIVATE _c_mindrecord)
if(ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))
target_link_libraries(_c_dataengine PRIVATE mindspore::pslite ${zeromq_DIRPATH}/zmq_install/lib/libzmq.a)
if(${ENABLE_IBVERBS} STREQUAL "ON")
target_link_libraries(_c_dataengine PRIVATE ibverbs rdmacm)
endif()

View File

@ -460,7 +460,7 @@ bool StartPSWorkerAction(const ResourcePtr &res) {
bool StartPSServerAction(const ResourcePtr &res) {
FuncGraphPtr func_graph = res->func_graph();
auto &ps = ps::ParameterServer<float>::GetInstance();
auto &ps = ps::ParameterServer::GetInstance();
ps.Run(func_graph);
return true;
}
@ -626,7 +626,7 @@ std::vector<ActionItem> VmPipeline() {
actions.emplace_back(std::make_pair("validate", ValidateAction));
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
if (ps::Util::IsRoleOfWorker()) {
if (ps::PSContext::instance()->is_worker()) {
actions.emplace_back(std::make_pair("worker", StartPSWorkerAction));
}
#endif

View File

@ -43,6 +43,7 @@
#include "pipeline/jit/static_analysis/auto_monad.h"
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
#include "ps/util.h"
#include "ps/ps_context.h"
#endif
namespace mindspore {
@ -406,7 +407,7 @@ bool AddRecomputationPass(const ResourcePtr &res) {
bool AddCacheEmbeddingPass(const ResourcePtr &res) {
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
if (ps::Util::IsParamServerMode()) {
if (ps::PSContext::instance()->is_ps_mode()) {
return true;
}
#endif

View File

@ -49,7 +49,7 @@
#include "utils/shape_utils.h"
#include "utils/info.h"
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
#include "ps/common.h"
#include "ps/constants.h"
#include "ps/util.h"
#include "ps/worker.h"
#include "ps/ps_cache/ps_data/ps_data_prefetch.h"
@ -492,14 +492,11 @@ std::vector<ActionItem> GetPipline(const ResourcePtr &resource, const std::strin
std::string backend = MsContext::GetInstance()->backend_policy();
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
if (mindspore::ps::Util::IsParamServerMode()) {
mindspore::ps::Util::SetInternalEnvVar();
}
if (ps::Util::IsRoleOfPServer()) {
if (ps::PSContext::instance()->is_server()) {
resource->results()[kBackend] = compile::CreateBackend();
return PServerPipeline();
}
if (ps::Util::IsRoleOfScheduler()) {
if (ps::PSContext::instance()->is_scheduler()) {
return PSchedulerPipeline();
}
#endif
@ -978,7 +975,7 @@ bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batc
const std::vector<TypePtr> &types, const std::vector<std::vector<int64_t>> &shapes,
const std::vector<int64_t> &input_indexes, bool need_run) {
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
if ((ps::Util::IsParamServerMode()) && (!ps::Util::IsRoleOfWorker())) {
if ((ps::PSContext::instance()->is_ps_mode()) && (!ps::PSContext::instance()->is_worker())) {
return true;
}
#endif
@ -1030,7 +1027,7 @@ bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batc
ConfigManager::GetInstance().set_iter_num(size);
// PS cache does not support loop sink.
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
if (ps::Util::IsRoleOfWorker() && ps::PsDataPrefetch::GetInstance().cache_enable()) {
if (ps::PSContext::instance()->is_worker() && ps::PsDataPrefetch::GetInstance().cache_enable()) {
ps::PsDataPrefetch::GetInstance().CreateDataChannel(queue_name, LongToSize(size));
ConfigManager::GetInstance().set_iter_num(1);
}
@ -1151,10 +1148,11 @@ void ClearResAtexit() {
pynative::ClearPyNativeSession();
session::ClearPythonParasMap();
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
if (ps::Util::IsParamServerMode() && ps::Util::IsRoleOfWorker()) {
if (ps::PSContext::instance()->is_ps_mode() && ps::PSContext::instance()->is_worker()) {
if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
ps::ps_cache_instance.Finalize();
}
MS_LOG(INFO) << "ps::worker.Finalize";
ps::worker.Finalize();
}
#endif

View File

@ -21,8 +21,8 @@ if(NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)))
list(REMOVE_ITEM _PS_SRC_FILES "core/abstract_node.cc")
list(REMOVE_ITEM _PS_SRC_FILES "core/scheduler_node.cc")
list(REMOVE_ITEM _PS_SRC_FILES "core/http_client.cc")
list(REMOVE_ITEM _PS_SRC_FILES "internal/worker.cc")
list(REMOVE_ITEM _PS_SRC_FILES "internal/parameter_server.cc")
list(REMOVE_ITEM _PS_SRC_FILES "worker.cc")
list(REMOVE_ITEM _PS_SRC_FILES "parameter_server.cc")
endif()
if(NOT ENABLE_D)

View File

@@ -1,140 +0,0 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_PS_COMMON_H_
#define MINDSPORE_CCSRC_PS_COMMON_H_
#include <limits.h>
#include <iostream>
#include <vector>
#include <memory>
#include <map>
#include <string>
#include "ps/ps.h"
namespace mindspore {
namespace ps {
// Shared constants, type aliases and optimizer input-index tables used by the
// parameter-server (PS) worker/server code built on top of ps-lite.
// Environment variables read on the MindSpore side to configure the cluster.
constexpr char kEnvCommType[] = "MS_COMM_TYPE";
constexpr char kEnvInterface[] = "MS_INTERFACE";
constexpr char kEnvPServerNum[] = "MS_SERVER_NUM";
constexpr char kEnvWorkerNum[] = "MS_WORKER_NUM";
constexpr char kEnvSchedulerHost[] = "MS_SCHED_HOST";
constexpr char kEnvSchedulerPort[] = "MS_SCHED_PORT";
// Corresponding ps-lite (DMLC) environment variable names the values map onto.
constexpr char kDmlcCommType[] = "DMLC_PS_VAN_TYPE";
constexpr char kDmlcInterface[] = "DMLC_INTERFACE";
constexpr char kDmlcPServerNum[] = "DMLC_NUM_SERVER";
constexpr char kDmlcWorkerNum[] = "DMLC_NUM_WORKER";
constexpr char kDmlcRole[] = "DMLC_ROLE";
constexpr char kDmlcSchedulerHost[] = "DMLC_PS_ROOT_URI";
constexpr char kDmlcSchedulerPort[] = "DMLC_PS_ROOT_PORT";
// Transport types and cluster role names.
constexpr char kCommTypeOfIBVerbs[] = "ibverbs";
constexpr char kCommTypeOfTCP[] = "zmq";
constexpr char kRoleOfPServer[] = "server";
constexpr char kRoleOfWorker[] = "worker";
constexpr char kRoleOfScheduler[] = "scheduler";
// Optimizer input/operator name strings.
constexpr char kLearningRate[] = "learning_rate";
constexpr char kMomentum[] = "momentum";
constexpr char kApplyMomentum[] = "ApplyMomentum";
constexpr char kSparseAdam[] = "Adam";
constexpr char kSparseLazyAdam[] = "LazyAdam";
constexpr char kSparseFtrl[] = "Ftrl";
constexpr char kApplyMomentumOp[] = "Momentum";
constexpr char kSparseAdamOp[] = "Adam";
constexpr char kSparseLazyAdamOp[] = "LazyAdam";
constexpr char kSparseFtrlOp[] = "FTRL";
// Command ids carried in worker<->server messages.
constexpr int64_t kInitWeightsCmd = 10;
constexpr int64_t kInitWeightToOptimIdCmd = 11;
constexpr int64_t kInitOptimInputsShapeCmd = 12;
constexpr int64_t kInitKeyToPushNodeIdCmd = 13;
constexpr int64_t kInitEmbeddingsCmd = 20;
constexpr int64_t kUpdateEmbeddingsCmd = 21;
constexpr int64_t kCheckReadyForPushCmd = 25;
constexpr int64_t kCheckReadyForPullCmd = 26;
constexpr int64_t kEmbeddingLookupCmd = 30;
constexpr int64_t kFinalizeCmd = 40;
// Sentinels for "no key" / "no id".
constexpr size_t kInvalidKey = UINT64_MAX;
constexpr int64_t kInvalidID = -1;
// Aliases over ps-lite shared-array (SArray) and key types.
using DataPtr = std::shared_ptr<unsigned char>;
using VectorPtr = std::shared_ptr<std::vector<unsigned char>>;
using Key = ::ps::Key;
using Keys = ::ps::SArray<Key>;
using Values = ::ps::SArray<float>;
using ValuesPtr = std::shared_ptr<Values>;
using Weight = ::ps::SArray<float>;
using Grad = ::ps::SArray<float>;
using LookupIds = ::ps::SArray<Key>;
using Lengths = ::ps::SArray<int>;
using WeightPtr = std::shared_ptr<Weight>;
using GradPtr = std::shared_ptr<Grad>;
using InputsShape = std::vector<std::shared_ptr<std::vector<size_t>>>;
using InputsShapePtr = std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>>;
// Marks an optimizer input that is kept server-side and never sent by the worker.
constexpr size_t INDEX_NOT_SEND = UINT_MAX;
// Maps: optimizer input name -> index in the original kernel input list /
// index in the message actually sent to the parameter server.
using OptimOriginIdx = std::map<std::string, size_t>;
using OptimPSSendIdx = std::map<std::string, size_t>;
const OptimOriginIdx kMomentumOriginIdx = {{"weight", 0}, {"accum", 1}, {"lr", 2}, {"grad", 3}, {"momentum", 4}};
const OptimPSSendIdx kMomentumPSSendIdx = {
{"weight", INDEX_NOT_SEND}, {"accum", INDEX_NOT_SEND}, {"lr", 0}, {"grad", 1}, {"momentum", 2}};
const OptimOriginIdx kSparseAdamOriginIdx = {{"weight", 0}, {"m", 1}, {"v", 2}, {"beta1_power", 3},
{"beta2_power", 4}, {"lr", 5}, {"beta1", 6}, {"beta2", 7},
{"eps", 8}, {"grad", 9}, {"indices", 10}};
const OptimPSSendIdx kSparseAdamPSSendIdx = {{"weight", INDEX_NOT_SEND},
{"m", INDEX_NOT_SEND},
{"v", INDEX_NOT_SEND},
{"beta1_power", 0},
{"beta2_power", 1},
{"lr", 2},
{"beta1", 3},
{"beta2", 4},
{"eps", 5},
{"grad", 6},
{"indices", 7}};
const OptimOriginIdx kSparseFtrlOriginIdx = {{"weight", 0}, {"accum", 1}, {"linear", 2}, {"grad", 3}, {"indices", 4}};
const OptimPSSendIdx kSparseFtrlPSSendIdx = {
{"weight", INDEX_NOT_SEND}, {"accum", INDEX_NOT_SEND}, {"linear", INDEX_NOT_SEND}, {"grad", 0}, {"indices", 1}};
const std::map<std::string, OptimOriginIdx> kOptimToOriginIdx = {{kApplyMomentum, kMomentumOriginIdx},
{kSparseAdam, kSparseAdamOriginIdx},
{kSparseLazyAdam, kSparseAdamOriginIdx},
{kSparseFtrl, kSparseFtrlOriginIdx}};
// NOTE(review): value type is OptimOriginIdx here rather than OptimPSSendIdx;
// both alias std::map<std::string, size_t>, so this compiles — presumably intentional.
const std::map<std::string, OptimOriginIdx> kOptimToPSSendIdx = {{kApplyMomentum, kMomentumPSSendIdx},
{kSparseAdam, kSparseAdamPSSendIdx},
{kSparseLazyAdam, kSparseAdamPSSendIdx},
{kSparseFtrl, kSparseFtrlPSSendIdx}};
// Raises MS_LOG(EXCEPTION) when idx is out of bounds for vec.
#define EXC_IF_VEC_IDX_OOB(vec, idx) \
{ \
size_t vec_size = vec.size(); \
if (idx >= vec_size) { \
MS_LOG(EXCEPTION) << "Vector " << #vec << " size is " << vec_size << ". So index " << idx \
<< " is out of bound."; \
} \
}
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_COMMON_H_

View File

@ -14,10 +14,11 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_INTERNAL_CONSTANTS_H_
#define MINDSPORE_CCSRC_PS_INTERNAL_CONSTANTS_H_
#ifndef MINDSPORE_CCSRC_PS_CONSTANTS_H_
#define MINDSPORE_CCSRC_PS_CONSTANTS_H_
#include <limits.h>
#include <climits>
#include <iostream>
#include <vector>
#include <memory>
@ -26,8 +27,6 @@
namespace mindspore {
namespace ps {
namespace internal {
constexpr char kEnvCommType[] = "MS_COMM_TYPE";
constexpr char kEnvInterface[] = "MS_INTERFACE";
constexpr char kEnvPServerNum[] = "MS_SERVER_NUM";
@ -127,7 +126,6 @@ const std::map<std::string, OptimOriginIdx> kOptimToPSSendIdx = {{kApplyMomentum
<< " is out of bound."; \
} \
}
} // namespace internal
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_INTERNAL_CONSTANTS_H_
#endif // MINDSPORE_CCSRC_PS_CONSTANTS_H_

View File

@ -39,9 +39,9 @@ void ClusterMetadata::Init(const uint32_t &worker_num, const uint32_t &server_nu
scheduler_port_ = scheduler_port;
}
uint32_t ClusterMetadata::worker_num() { return worker_num_; }
uint32_t ClusterMetadata::total_worker_num() { return worker_num_; }
uint32_t ClusterMetadata::server_num() { return server_num_; }
uint32_t ClusterMetadata::total_server_num() { return server_num_; }
uint32_t ClusterMetadata::heartbeat_interval() { return heartbeat_interval_; }

View File

@ -37,8 +37,8 @@ class ClusterMetadata {
void Init(const uint32_t &worker_num, const uint32_t &server_num, std::string scheduler_host,
const uint16_t &scheduler_port);
uint32_t worker_num();
uint32_t server_num();
uint32_t total_worker_num();
uint32_t total_server_num();
uint32_t heartbeat_interval();
void set_heartbeat_interval(const uint32_t &heartbeat_interval);
std::string scheduler_host();

View File

@ -122,9 +122,9 @@ std::string CommUtil::NodeRoleToString(const NodeRole &role) {
}
}
bool CommUtil::ValidateRankId(const enum NodeRole &node_role, const uint32_t &rank_id) {
if (node_role == NodeRole::SERVER && (rank_id > ClusterMetadata::instance()->server_num() - 1)) {
if (node_role == NodeRole::SERVER && (rank_id > ClusterMetadata::instance()->total_server_num() - 1)) {
return false;
} else if (node_role == NodeRole::WORKER && (rank_id > ClusterMetadata::instance()->worker_num() - 1)) {
} else if (node_role == NodeRole::WORKER && (rank_id > ClusterMetadata::instance()->total_worker_num() - 1)) {
return false;
}
return true;

View File

@ -20,7 +20,7 @@ namespace mindspore {
namespace ps {
namespace core {
void NodeManager::InitNodeNum() {
total_node_num_ = ClusterMetadata::instance()->server_num() + ClusterMetadata::instance()->worker_num();
total_node_num_ = ClusterMetadata::instance()->total_server_num() + ClusterMetadata::instance()->total_worker_num();
}
int NodeManager::NextRankId(const RegisterMessage &register_message) {

View File

@@ -1,179 +0,0 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_PS_INTERNAL_PARAMETER_SERVER_H_
#define MINDSPORE_CCSRC_PS_INTERNAL_PARAMETER_SERVER_H_
#include <unistd.h>
#include <unordered_map>
#include <string>
#include <iostream>
#include <memory>
#include <vector>
#include <mutex>
#include <condition_variable>
#include <thread>
#include <cmath>
#include <random>
#include <utility>
#include <list>
#include <map>
#include <functional>
#include "ir/func_graph.h"
#include "backend/session/session_basic.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/session/session_factory.h"
#include "ps/optimizer_info.h"
#include "ps/optimizer_info_builder.h"
#include "ps/ps_context.h"
#include "runtime/device/cpu/kernel_select_cpu.h"
#include "utils/ms_context.h"
#include "backend/kernel_compiler/kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
#include "backend/kernel_compiler/cpu/ps/pserver_kernel.h"
#include "backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.h"
#include "backend/kernel_compiler/cpu/ps/sparse_apply_lazy_adam_ps_kernel.h"
#include "backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.h"
#include "backend/kernel_compiler/cpu/ps/apply_momentum_ps_kernel.h"
#include "backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.h"
#include "ps/ps_cache/ps_data/ps_data_prefetch.h"
#include "ps/random_normal/random_normal.h"
#include "ps/internal/constants.h"
#include "ps/util.h"
#include "ps/embedding_table_shard_metadata.h"
#include "utils/log_adapter.h"
#include "proto/comm.pb.h"
#include "proto/ps.pb.h"
#include "ps/core/server_node.h"
namespace mindspore {
namespace ps {
namespace internal {
// Server-side singleton of the parameter-server runtime: holds weights,
// accumulates gradients from workers, runs optimizer kernels, and serves
// embedding lookups. Driven by messages dispatched through ServerHandler.
class ParameterServer {
 public:
  // Meyers singleton accessor; construction is lazy and thread-safe (C++11).
  static ParameterServer &GetInstance() {
    static ParameterServer instance;
    return instance;
  }
  // Starts the server loop for the given graph (blocks until finalized).
  void Run(const FuncGraphPtr &func_graph);
 private:
  ParameterServer()
      : pserver_num_(0),
        worker_num_(0),
        rank_id_(0),
        grad_accum_count_(0),
        handler_(nullptr),
        func_graph_(nullptr),
        sess_(nullptr),
        running_(true),
        thread_(nullptr) {}
  ~ParameterServer() = default;
  // Non-copyable singleton.
  ParameterServer(const ParameterServer &) = delete;
  ParameterServer &operator=(const ParameterServer &) = delete;
  // Decodes incoming TCP messages and routes each command id to the matching
  // Handle* method via the handlers_ table.
  class ServerHandler {
   public:
    explicit ServerHandler(ParameterServer *ps) : ps_(ps) {}
    ~ServerHandler() = default;
    void Init();
    // Entry point invoked per received message; writes the reply into res.
    void operator()(std::shared_ptr<core::TcpConnection> conn, std::shared_ptr<core::MessageMeta> meta, DataPtr data,
                    size_t size);
    void HandlePushReq(DataPtr data, size_t size, VectorPtr res);
    void HandlePullReq(DataPtr data, size_t size, VectorPtr res);
    void HandleInitWeights(DataPtr data, size_t size, VectorPtr res);
    void HandleInitWeightToOptimId(DataPtr data, size_t size, VectorPtr res);
    void HandleInitInputsShape(DataPtr data, size_t size, VectorPtr res);
    void HandleInitEmbeddings(DataPtr data, size_t size, VectorPtr res);
    void HandleCheckReadyForPush(DataPtr data, size_t size, VectorPtr res);
    void HandleCheckReadyForPull(DataPtr data, size_t size, VectorPtr res);
    void HandleEmbeddingLookup(DataPtr data, size_t size, VectorPtr res);
    void HandleUpdateEmbeddings(DataPtr data, size_t size, VectorPtr res);
    void HandleFinalize(DataPtr data, size_t size, VectorPtr res);
   private:
    ParameterServer *ps_;  // back-pointer to the owning server (not owned)
    typedef void (ServerHandler::*RequestHandler)(DataPtr data, size_t size, VectorPtr res);
    std::unordered_map<int, RequestHandler> handlers_;  // command id -> member handler
    std::unordered_map<Key, bool> init_weights_;
    std::unordered_map<Key, bool> init_weight_to_optim_;
    std::unordered_map<Key, bool> init_optim_info_;
  };
  // Initialization of the server state from the compiled graph.
  bool Init(const FuncGraphPtr &func_graph);
  void InitOptimInfoBuilders();
  void InitWeightKeyToOptims(const Key &key, const int64_t &optim_id);
  void InitOptimInputsShape(const Keys &keys, const Values &values, const Lengths &lengths);
  void InitWeight(const Key &key, const WeightPtr &weight);
  void InitGrad(const Key &key, const GradPtr &grad);
  void InitEmbeddingTable(const Key &key,
                          const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes,
                          const ParamInitInfo &param_init_info);
  bool HasWeight(const Key &key);
  void Finalize();
  // Gradient accumulation / weight update cycle.
  void UpdateWeights();
  void AccumGrad(const Keys &key, const Values &values, const Lengths &lengths);
  WeightPtr weight(const Key &key);
  void DoEmbeddingLookup(Key key, const LookupIds &lookup_ids, KVMessage *res);
  void UpdateEmbeddings(const Key &key, const LookupIds &lookup_ids, const Values &vals);
  bool ReadyForUpdateWeights();
  bool ReadyForPush(const Key &key);
  bool ReadyForPull(const Key &key);
  void ResetGradAccumCount();
  const CNodePtr GetCNode(const std::string &name) const;
  std::mutex &mutex();
  void GetEmbeddingTableParamPtr();
  void SyncEmbeddingTables();
  // Cluster topology and progress counters.
  size_t pserver_num_;
  size_t worker_num_;
  size_t rank_id_;
  size_t grad_accum_count_;
  std::unique_ptr<ServerHandler> handler_;
  FuncGraphPtr func_graph_;
  std::shared_ptr<session::SessionBasic> sess_;
  bool running_;
  // Per-key optimizer kernels, shapes, cached optimizer state and weights.
  std::unordered_map<Key, std::shared_ptr<PServerKernel>> optimizers_;
  std::unordered_map<Key, InputsShapePtr> optim_inputs_shape_;
  std::unordered_map<Key, InputsShapePtr> original_optim_inputs_shape_;
  std::unordered_map<Key, std::shared_ptr<OptimizerInfo>> optim_infos_;
  std::unordered_map<std::string, std::shared_ptr<OptimizerInfoBuilder>> optim_info_builders_;
  std::unordered_map<Key, std::string> weight_key_to_optims_;
  std::unordered_map<Key, std::string> weight_key_to_optim_op_;
  std::unordered_map<Key, WeightPtr> weights_;
  std::unordered_map<Key, bool> is_embedding_;
  std::unordered_map<Key, WeightPtr> grads_;
  std::unordered_map<Key, size_t> grads_accum_counter_;
  std::unordered_map<Key, std::shared_ptr<PServerKernel>> embedding_lookup_ops_;
  std::unordered_map<Key, uint64_t> tokens_;
  // Synchronization for the apply-gradients thread.
  std::mutex mutex_;
  std::condition_variable apply_grads_cv_;
  std::unique_ptr<std::thread> thread_;
  core::ServerNode server_node_;
  std::map<Key, ParameterPtr> embedding_tables_;
  friend class ServerHandler;
};
} // namespace internal
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_INTERNAL_PARAMETER_SERVER_H_

View File

@@ -1,157 +0,0 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_PS_INTERNAL_WORKER_H_
#define MINDSPORE_CCSRC_PS_INTERNAL_WORKER_H_
#include <utility>
#include <memory>
#include <vector>
#include <string>
#include <numeric>
#include <functional>
#include <algorithm>
#include <map>
#include <mutex>
#include <unordered_set>
#include <unordered_map>
#include "utils/log_adapter.h"
#include "ir/tensor.h"
#include "ps/util.h"
#include "ps/internal/constants.h"
#include "utils/shape_utils.h"
#include "ps/ps_cache/ps_data/ps_data_prefetch.h"
#include "ps/core/worker_node.h"
#include "ps/embedding_table_shard_metadata.h"
#include "proto/comm.pb.h"
#include "proto/ps.pb.h"
#include "ps/ps_context.h"
namespace mindspore {
namespace ps {
namespace internal {
// Worker-side singleton of the parameter-server runtime: pushes gradients,
// pulls weights, and performs embedding lookups against the servers. Messages
// are split across servers by the partitioner callbacks declared below.
class Worker {
 public:
  // Meyers singleton accessor; construction is lazy and thread-safe (C++11).
  static Worker &GetInstance() {
    static Worker instance;
    return instance;
  }
  using Callback = std::function<void()>;
  // Per-server message fragments; the bool marks whether a fragment is sent.
  using PartitionEmbeddingMessages = std::vector<std::pair<bool, EmbeddingTableLookup>>;
  using PartitionKVMessages = std::vector<std::pair<bool, KVMessage>>;
  // Splits one logical request into per-server fragments.
  using EmbeddingPartitioner = std::function<void(
    const EmbeddingTableLookup &send, PartitionEmbeddingMessages *partition, const std::map<int64_t, int64_t> &attrs)>;
  using KVPartitioner =
    std::function<void(const KVMessage &send, PartitionKVMessages *partition, const std::map<int64_t, int64_t> &attrs)>;
  void Run();
  // Push parameter data at addrs (with the given sizes) to the servers.
  void Push(const std::vector<size_t> &keys, std::vector<uintptr_t> addrs, const ShapeVector &sizes);
  // Pull the weight for key into dev_addr.
  void Pull(const size_t key, void *dev_addr, const size_t size);
  size_t SetParamKey(const std::string &param_name);
  size_t GetParamKey(const std::string &param_name);
  void SetParamInitInServer(const std::string &param_name, bool init_in_server);
  bool GetParamInitInServer(const std::string &param_name);
  void SetKeyOptimId(size_t key, const std::string &optimizer_name);
  void SetOptimInputShapes(size_t key, const ShapeVector &shape);
  void AddEmbeddingTable(const Key &key, const size_t &row_count);
  void InitPSEmbeddingTable(const size_t &key, const std::vector<size_t> &input_shape,
                            const std::vector<size_t> &indices_shape, const std::vector<size_t> &output_shape);
  void InitPSParamAndOptim(const AnfNodePtr &input_node, const tensor::TensorPtr &tensor);
  // Look up embedding rows for lookup_ids; results are written to lookup_result.
  void DoPSEmbeddingLookup(const Key &key, const std::vector<int> &lookup_ids, std::vector<float> *lookup_result,
                           int64_t cmd);
  void UpdateEmbeddingTable(const std::vector<Key> &keys, const std::vector<int> &lookup_ids,
                            const std::vector<float> &vals);
  bool running() { return running_; }
  void Finalize();
 private:
  Worker() : running_(false), key_cnt_(0) {}
  ~Worker() = default;
  // Non-copyable singleton.
  Worker(const Worker &) = delete;
  Worker &operator=(const Worker &) = delete;
  void Initialize();
  bool IsKeyInit(const size_t key);
  void AddKeyToServerId(const Key &key);
  void AddKeyByHashMod(const Key &key);
  void InitPSOptimId(const size_t param_key);
  void InitPSOptimInputShapes(const size_t key);
  void InitPSParamData(const std::vector<size_t> &keys, void *origin_addr, size_t size);
  bool IsReadyForPush(const Key &key);
  bool IsReadyForPull(const Key &key);
  // Sparse-gradient assembly helpers.
  void PrepareSparseGradient(const size_t begin, const size_t end, const std::unordered_set<int> &distinct_ids,
                             const std::vector<std::pair<int, float *>> &indice_to_grads, const int *all_indice,
                             const size_t segment_size, float *gradient, int *indices);
  void BuildSparseValue(const std::vector<int> &lengths, const size_t grad_index, const size_t indice_index,
                        const float *original_data, const float *grads, int *indices, std::vector<float> *reduced_data);
  // Low-level send/receive primitives.
  void PushData(const std::vector<Key> &keys, const std::vector<float> &vals, const std::vector<int> &lens = {},
                int command = 0, int64_t priority = 0);
  void PushSparseData(const std::vector<Key> &keys, const std::vector<float> &vals, const std::vector<int> &lens,
                      size_t grad_index, size_t indice_index, size_t first_dim_size, size_t outer_dim_size);
  void PullData(const std::vector<Key> &keys, std::vector<float> *vals, std::vector<int> *lens = nullptr, int cmd = 0,
                int64_t priority = 0);
  // Partitioner strategies used to shard one request across servers.
  void LookupIdPartitioner(const EmbeddingTableLookup &send, PartitionEmbeddingMessages *partition,
                           const std::map<int64_t, int64_t> &attrs);
  void SparsePartitioner(const KVMessage &send, PartitionKVMessages *partition,
                         const std::map<int64_t, int64_t> &attrs);
  void RoundRobinPartitioner(const KVMessage &send, PartitionKVMessages *partition,
                             const std::map<int64_t, int64_t> &attrs);
  void WorkerInitEmbeddingPartitioner(const KVMessage &send, std::vector<std::pair<bool, KVMessage>> *partition,
                                      const std::map<int64_t, int64_t> &attrs);
  void UpdateEmbeddingPartitioner(const KVMessage &send, PartitionKVMessages *partition,
                                  const std::map<int64_t, int64_t> &attrs);
  void BroadcastPartitioner(const KVMessage &send, PartitionKVMessages *partition,
                            const std::map<int64_t, int64_t> &attrs);
  void SendForPush(int cmd, const KVMessage &send, const KVPartitioner &partitioner,
                   const std::map<int64_t, int64_t> &attrs);
  void SendForPull(int cmd, const KVMessage &send, const KVPartitioner &partitioner,
                   const std::map<int64_t, int64_t> &attrs, std::vector<float> *vals, std::vector<int> *lens);
  int64_t server_num_;
  bool running_;
  std::mutex running_mutex_;
  size_t key_cnt_;
  // Parameter-name / key bookkeeping.
  std::map<std::string, size_t> param_to_key_;
  std::map<size_t, bool> init_keys_;
  std::map<size_t, int64_t> key_to_optimId_;
  std::map<size_t, std::vector<ShapeVector>> key_to_optim_shapes_;
  std::map<std::string, bool> param_to_init_in_server_;
  core::WorkerNode worker_node_;
  // Installed partitioner callbacks (see the strategy methods above).
  EmbeddingPartitioner lookup_partitioner_;
  KVPartitioner sparse_partitioner_;
  KVPartitioner round_robin_partitioner_;
  KVPartitioner worker_init_embedding_partitioner_;
  KVPartitioner update_embedding_partitioner_;
  KVPartitioner broadcast_partitioner_;
  std::unordered_map<Key, int64_t> key_to_server_id_;
  std::unordered_map<Key, size_t> embedding_row_cnt_;
  std::unordered_map<Key, std::shared_ptr<std::vector<EmbeddingTableShardMetadata>>> embedding_table_ranges_;
};
// Global convenience reference used throughout the worker-side code.
static Worker &worker = Worker::GetInstance();
} // namespace internal
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_INTERNAL_WORKER_H_

View File

@ -84,7 +84,7 @@ void DenseOptimInfo::Accumulate(const Values &values, const Lengths &lengths) {
for (size_t i = 0; i < grad_index; i++) {
grad_offset += lengths[i];
}
float *grad_data = values.data() + grad_offset;
float *grad_data = const_cast<float *>(values.data()) + grad_offset;
CHECK_EQ(size, static_cast<size_t>(lengths[grad_index]));
for (size_t i = 0; i < size; i++) {
@ -121,7 +121,7 @@ void SparseOptimInfo::Accumulate(const Values &values, const Lengths &lengths) {
for (size_t i = 0; i < grad_index; i++) {
grad_offset += lengths[i];
}
float *incr_grad_data = values.data() + grad_offset;
float *incr_grad_data = const_cast<float *>(values.data()) + grad_offset;
MS_EXCEPTION_IF_NULL(incr_grad_data);
size_t incr_grad_size = lengths[grad_index] * sizeof(float);
@ -148,7 +148,11 @@ void SparseOptimInfo::Accumulate(const Values &values, const Lengths &lengths) {
for (size_t i = 0; i < indices_index; i++) {
indice_offset += lengths[i];
}
int *incr_indice_data = reinterpret_cast<int *>(values.data()) + indice_offset;
void *incr_indice_data_temp = const_cast<float *>(values.data()) + indice_offset;
int *incr_indice_data = reinterpret_cast<int *>(incr_indice_data_temp);
MS_EXCEPTION_IF_NULL(incr_indice_data);
size_t incr_indice_size = lengths[indices_index];
size_t incr_indice_data_size = incr_indice_size * sizeof(int);
@ -259,7 +263,7 @@ MomentumOptimInfo::MomentumOptimInfo(const AddressPtr &weight, const AddressPtr
}
void MomentumOptimInfo::Update(const Values &values, const Lengths &lens) {
UpdateOptimInputValue<float>(kApplyMomentum, "lr", values.data(), lens);
UpdateOptimInputValue<float>(kApplyMomentum, "lr", const_cast<float *>(values.data()), lens);
}
const size_t SparseOptimInfo::indice_size() const { return indices_offset_; }
@ -303,12 +307,12 @@ SparseAdamOptimInfo::SparseAdamOptimInfo(const AddressPtr &weight, const Address
}
void SparseAdamOptimInfo::Update(const Values &values, const Lengths &lens) {
UpdateOptimInputValue<float>(kSparseAdam, "beta1_power", values.data(), lens);
UpdateOptimInputValue<float>(kSparseAdam, "beta2_power", values.data(), lens);
UpdateOptimInputValue<float>(kSparseAdam, "lr", values.data(), lens);
UpdateOptimInputValue<float>(kSparseAdam, "beta1", values.data(), lens);
UpdateOptimInputValue<float>(kSparseAdam, "beta2", values.data(), lens);
UpdateOptimInputValue<float>(kSparseAdam, "eps", values.data(), lens);
UpdateOptimInputValue<float>(kSparseAdam, "beta1_power", const_cast<float *>(values.data()), lens);
UpdateOptimInputValue<float>(kSparseAdam, "beta2_power", const_cast<float *>(values.data()), lens);
UpdateOptimInputValue<float>(kSparseAdam, "lr", const_cast<float *>(values.data()), lens);
UpdateOptimInputValue<float>(kSparseAdam, "beta1", const_cast<float *>(values.data()), lens);
UpdateOptimInputValue<float>(kSparseAdam, "beta2", const_cast<float *>(values.data()), lens);
UpdateOptimInputValue<float>(kSparseAdam, "eps", const_cast<float *>(values.data()), lens);
}
const AddressPtr &SparseAdamOptimInfo::gradient() {

View File

@ -20,7 +20,7 @@
#include <vector>
#include <string>
#include "backend/kernel_compiler/kernel.h"
#include "ps/common.h"
#include "ps/constants.h"
namespace mindspore {
namespace ps {

View File

@ -129,9 +129,9 @@ OptimizerInfo *MomentumOptimInfoBuilder::BuildInputs(const WeightPtr &weight, co
return nullptr;
}
AddressPtr learning_rate = GenInputAddrPtr<float>(kApplyMomentum, "lr", values.data(), lens);
AddressPtr gradient = GenInputAddrPtr<float>(kApplyMomentum, "grad", values.data(), lens);
AddressPtr momentum = GenInputAddrPtr<float>(kApplyMomentum, "momentum", values.data(), lens);
AddressPtr learning_rate = GenInputAddrPtr<float>(kApplyMomentum, "lr", const_cast<float *>(values.data()), lens);
AddressPtr gradient = GenInputAddrPtr<float>(kApplyMomentum, "grad", const_cast<float *>(values.data()), lens);
AddressPtr momentum = GenInputAddrPtr<float>(kApplyMomentum, "momentum", const_cast<float *>(values.data()), lens);
return new MomentumOptimInfo(weight_addr, accumulate, learning_rate, gradient, momentum);
}
@ -172,14 +172,15 @@ OptimizerInfo *SparseAdamOptimInfoBuilder::BuildInputs(const WeightPtr &weight,
return nullptr;
}
AddressPtr beta1_power = GenInputAddrPtr<float>(kSparseAdam, "beta1_power", values.data(), lens);
AddressPtr beta2_power = GenInputAddrPtr<float>(kSparseAdam, "beta2_power", values.data(), lens);
AddressPtr learning_rate = GenInputAddrPtr<float>(kSparseAdam, "lr", values.data(), lens);
AddressPtr beta1 = GenInputAddrPtr<float>(kSparseAdam, "beta1", values.data(), lens);
AddressPtr beta2 = GenInputAddrPtr<float>(kSparseAdam, "beta2", values.data(), lens);
AddressPtr epsilon = GenInputAddrPtr<float>(kSparseAdam, "eps", values.data(), lens);
AddressPtr grad = GenInputAddrPtr<float>(kSparseAdam, "grad", values.data(), lens, inputs_shape);
AddressPtr indices = GenInputAddrPtr<float>(kSparseAdam, "indices", values.data(), lens, inputs_shape);
AddressPtr beta1_power = GenInputAddrPtr<float>(kSparseAdam, "beta1_power", const_cast<float *>(values.data()), lens);
AddressPtr beta2_power = GenInputAddrPtr<float>(kSparseAdam, "beta2_power", const_cast<float *>(values.data()), lens);
AddressPtr learning_rate = GenInputAddrPtr<float>(kSparseAdam, "lr", const_cast<float *>(values.data()), lens);
AddressPtr beta1 = GenInputAddrPtr<float>(kSparseAdam, "beta1", const_cast<float *>(values.data()), lens);
AddressPtr beta2 = GenInputAddrPtr<float>(kSparseAdam, "beta2", const_cast<float *>(values.data()), lens);
AddressPtr epsilon = GenInputAddrPtr<float>(kSparseAdam, "eps", const_cast<float *>(values.data()), lens);
AddressPtr grad = GenInputAddrPtr<float>(kSparseAdam, "grad", const_cast<float *>(values.data()), lens, inputs_shape);
AddressPtr indices =
GenInputAddrPtr<float>(kSparseAdam, "indices", const_cast<float *>(values.data()), lens, inputs_shape);
return new SparseAdamOptimInfo(weight_addr, m, v, beta1_power, beta2_power, learning_rate, beta1, beta2, epsilon,
grad, indices, sharded);
}
@ -218,8 +219,9 @@ OptimizerInfo *SparseFtrlOptimInfoBuilder::BuildInputs(const WeightPtr &weight,
}
linear->size = weight->size() * sizeof(float);
AddressPtr grad = GenInputAddrPtr<float>(kSparseFtrl, "grad", values.data(), lens, inputs_shape);
AddressPtr indices = GenInputAddrPtr<float>(kSparseFtrl, "indices", values.data(), lens, inputs_shape);
AddressPtr grad = GenInputAddrPtr<float>(kSparseFtrl, "grad", const_cast<float *>(values.data()), lens, inputs_shape);
AddressPtr indices =
GenInputAddrPtr<float>(kSparseFtrl, "indices", const_cast<float *>(values.data()), lens, inputs_shape);
return new SparseFtrlOptimInfo(weight_addr, accum, linear, grad, indices, sharded);
}
} // namespace ps

View File

@ -14,12 +14,10 @@
* limitations under the License.
*/
#include "ps/internal/parameter_server.h"
#include "ps/parameter_server.h"
namespace mindspore {
namespace ps {
namespace internal {
void ParameterServer::Run(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_LOG(INFO) << "PServer starts connecting to scheduler and workers...";
@ -44,8 +42,8 @@ void ParameterServer::Run(const FuncGraphPtr &func_graph) {
}
bool ParameterServer::Init(const FuncGraphPtr &func_graph) {
pserver_num_ = std::strtol(mindspore::common::GetEnv(kEnvWorkerNum).c_str(), nullptr, 10);
worker_num_ = std::strtol(mindspore::common::GetEnv(kEnvPServerNum).c_str(), nullptr, 10);
pserver_num_ = std::strtol(mindspore::common::GetEnv(kEnvPServerNum).c_str(), nullptr, 10);
worker_num_ = std::strtol(mindspore::common::GetEnv(kEnvWorkerNum).c_str(), nullptr, 10);
func_graph_ = func_graph;
handler_.reset(new ServerHandler(this));
handler_->Init();
@ -257,12 +255,21 @@ void ParameterServer::AccumGrad(const Keys &keys, const Values &values, const Le
std::shared_ptr<OptimizerInfo> optim_info = optim_infos_[key];
// Create or update the optimizer info
std::shared_ptr<kernel::ps::PServerKernel> pserver_kernel = optimizers_[key];
if (pserver_kernel == nullptr) {
MS_LOG(EXCEPTION) << "no optimizer found for key " << key << " optim name " << weight_key_to_optims_[key];
if (optim_info == nullptr) {
const std::shared_ptr<OptimizerInfoBuilder> &builder = optim_info_builders_[weight_key_to_optims_[key]];
std::shared_ptr<kernel::ps::PServerKernel> pserver_kernel = optimizers_[key];
if (pserver_kernel == nullptr) {
MS_LOG(EXCEPTION) << "no optimizer found for key " << key << " optim name " << weight_key_to_optims_[key];
}
MS_EXCEPTION_IF_NULL(pserver_kernel);
OptimizerInfo *optim = builder->Build(pserver_kernel, weights_[key], keys, values, lengths,
optim_inputs_shape_[key], worker_num_, is_embedding_[key]);
optim_info.reset(optim);
optim_infos_[key] = optim_info;
} else {
optim_info->Update(values, lengths);
optim_info->Accumulate(values, lengths);
}
MS_EXCEPTION_IF_NULL(pserver_kernel);
optim_infos_[key] = optim_info;
}
grads_accum_counter_[key] += 1;
@ -373,7 +380,7 @@ inline bool ParameterServer::ReadyForPush(const Key &key) {
MS_LOG(EXCEPTION) << "The weights in server is empty. Many reasons could cause this: 1.The Worker didn't send "
"kInitWeightsCmd command. 2.The Server failed to initialize weights.";
}
MS_LOG(INFO) << "the grad_accum_count_:" << grad_accum_count_ << " the weights_:" << weights_.size()
MS_LOG(INFO) << "The grad_accum_count_:" << grad_accum_count_ << " the weights_:" << weights_.size()
<< " the token:" << (tokens_[key] <= 0);
return grad_accum_count_ < weights_.size() && tokens_[key] <= 0;
}
@ -544,11 +551,9 @@ void ParameterServer::ServerHandler::HandleInitWeights(DataPtr data, size_t size
for (int i = 0; i < key_num; i++) {
Key key = input.keys()[i];
size_t data_len = input.len_size() != key_num ? input.values_size() / key_num : input.len()[i];
MS_LOG(DEBUG) << "The data len:" << data_len;
if (!ps_->HasWeight(key)) {
WeightPtr weight_ptr = std::make_shared<std::vector<float>>(data_ptr + pos, data_ptr + (pos + data_len));
MS_LOG(DEBUG) << "The weight ptr:" << *weight_ptr;
MS_EXCEPTION_IF_NULL(weight_ptr);
ps_->InitWeight(key, weight_ptr);
@ -637,7 +642,7 @@ void ParameterServer::ServerHandler::HandleCheckReadyForPush(DataPtr data, size_
input.ParseFromArray(data.get(), size);
const Key &key = input.keys()[0];
bool ready = ps_->ReadyForPush(key);
MS_LOG(INFO) << "the ready is:" << ready;
MS_LOG(INFO) << "The ready is:" << ready;
KVMessage res_data;
res_data.add_keys(key);
res_data.add_values(ready);
@ -671,7 +676,6 @@ void ParameterServer::ServerHandler::HandleEmbeddingLookup(DataPtr data, size_t
EmbeddingTableLookup input;
input.ParseFromArray(data.get(), size);
const Key &key = input.key();
MS_LOG(DEBUG) << "The key is:" << key;
KVMessage res_data;
std::vector<Key> keys = {input.keys().begin(), input.keys().end()};
@ -701,6 +705,5 @@ void ParameterServer::ServerHandler::HandleFinalize(DataPtr data, size_t size, V
MS_EXCEPTION_IF_NULL(res);
ps_->Finalize();
}
} // namespace internal
} // namespace ps
} // namespace mindspore

File diff suppressed because it is too large Load Diff

View File

@ -145,7 +145,6 @@ const size_t &PsCacheManager::QueryHashTableSize(const std::string &param_name)
void PsCacheManager::Initialize() {
MS_LOG(INFO) << "PS cache initialize.";
if (!worker.running()) {
Util::SetInternalEnvVar();
worker.Run();
}
embedding_device_cache_ = std::make_shared<EmbeddingDeviceCache>(batch_elements_, vocab_cache_size_);
@ -177,22 +176,19 @@ void PsCacheManager::InitParameterServer() {
for (const auto &item : hash_tables_) {
const auto &param_name = item.first;
size_t key = worker.SetParamKey(param_name);
std::vector<size_t> keys{key, key, key, key, key, key};
std::vector<float> values{
SizeToFloat(item.second.vocab_size), SizeToFloat(item.second.embedding_size), 1, 1, 1, 1, 1};
std::vector<int64_t> lens{2, 2, 3};
const auto &hash_table_info = item.second;
const auto &param_init_info = hash_table_info.param_init_info_;
if (param_init_info.param_type_ == kWeight) {
lens.push_back(1);
} else if (param_init_info.param_type_ == kAccumulation) {
lens.push_back(2);
}
values.push_back(param_init_info.init_val_);
lens.push_back(param_init_info.global_seed_);
lens.push_back(param_init_info.op_seed_);
std::vector<size_t> input_shape = {item.second.vocab_size, item.second.embedding_size};
std::vector<size_t> indices_shape = {1, 1};
std::vector<size_t> output_shape = {1, 1, 1};
ParamInitInfoMessage info;
info.set_param_type(param_init_info.param_type_);
info.set_init_val(param_init_info.init_val_);
info.set_global_seed(param_init_info.global_seed_);
info.set_op_seed(param_init_info.op_seed_);
// if worker role
worker.InitPSEmbeddingTable(keys, values, lens);
worker.InitPSEmbeddingTable(key, input_shape, indices_shape, output_shape, info);
}
finish_init_parameter_server_ = true;
@ -245,7 +241,7 @@ void PsCacheManager::AllocMemForHashTable() {
}
void PsCacheManager::SetLocalIdRank() {
auto worker_num = ::ps::NumWorkers();
auto worker_num = PSContext::instance()->initial_worker_num();
auto local_shard_size = FloatToInt(std::ceil(SizeToFloat(vocab_size_) / worker_num));
vocab_cache_size_diff_ = local_shard_size - SizeToInt(vocab_cache_size_);
emb_table_slice_bounds_.first = local_shard_size * rank_id_;
@ -829,8 +825,8 @@ bool PsCacheManager::HashSwapHostToServer(size_t key, const HashTableInfo &hash_
if (swap_indices_size == 0) {
return true;
}
::ps::SArray<int> lookup_ids(swap_indices_size, 0);
::ps::SArray<float> swap_out_data;
std::vector<int> lookup_ids(swap_indices_size, 0);
std::vector<float> swap_out_data;
auto embedding_size = hash_info.embedding_size;
swap_out_data.resize(swap_indices_size * embedding_size);
auto host_hash_table_addr = reinterpret_cast<float *>(hash_info.host_address.get());
@ -857,22 +853,21 @@ bool PsCacheManager::HashSwapServerToHost(size_t key, const HashTableInfo &hash_
}
auto host_hash_table_addr = reinterpret_cast<float *>(hash_info.host_address.get());
auto embedding_size = hash_info.embedding_size;
::ps::SArray<int> lengths{swap_indices_size};
::ps::SArray<float> lookup_result(swap_indices_size * embedding_size, 0);
::ps::SArray<int> lookup_ids(swap_indices_size, 0);
std::vector<float> lookup_result(swap_indices_size * embedding_size, 0);
std::vector<int> lookup_ids(swap_indices_size, 0);
auto copy_len = swap_indices_size * sizeof(int);
auto ret = memcpy_s(lookup_ids.data(), copy_len, server_to_host_ids, copy_len);
if (ret != EOK) {
MS_LOG(ERROR) << "Lookup id memcpy failed.";
return false;
}
worker.DoPSEmbeddingLookup({key}, lookup_ids, lengths, &lookup_result, mindspore::ps::kEmbeddingLookupCmd);
worker.DoPSEmbeddingLookup(key, lookup_ids, &lookup_result, mindspore::ps::kEmbeddingLookupCmd);
RETURN_IF_FALSE(InsertHostHashTable(embedding_size, IntToSize(swap_indices_size), server_to_host_index,
lookup_result.data(), host_hash_table_addr));
return true;
}
bool PsCacheManager::HashSwapDeviceOut(int *swap_out_index, ::ps::SArray<float> *swap_out_data,
bool PsCacheManager::HashSwapDeviceOut(int *swap_out_index, std::vector<float> *swap_out_data,
const HashTableInfo &hash_info) {
MS_ERROR_IF_NULL(swap_out_index);
MS_ERROR_IF_NULL(swap_out_data);
@ -912,16 +907,15 @@ bool PsCacheManager::HashSwapDeviceIn(int *swap_in_ids, int *swap_in_index, cons
auto cache_vocab_size = hash_info.cache_vocab_size;
auto embedding_size = hash_info.embedding_size;
// Get id embs by swap_in_ids in host(Pipeline with hash swap-out in device).
::ps::SArray<int> lengths{swap_in_ids_size};
::ps::SArray<float> lookup_result(swap_in_ids_size * embedding_size, 0);
::ps::SArray<int> lookup_ids(swap_in_ids_size, 0);
std::vector<float> lookup_result(swap_in_ids_size * embedding_size, 0);
std::vector<int> lookup_ids(swap_in_ids_size, 0);
auto copy_len = swap_in_ids_size * sizeof(int);
auto ret = memcpy_s(lookup_ids.data(), copy_len, swap_in_ids, copy_len);
if (ret != EOK) {
MS_LOG(ERROR) << "Lookup id memcpy failed.";
return false;
}
worker.DoPSEmbeddingLookup({key}, lookup_ids, lengths, &lookup_result, mindspore::ps::kEmbeddingLookupCmd);
worker.DoPSEmbeddingLookup(key, lookup_ids, &lookup_result, mindspore::ps::kEmbeddingLookupCmd);
// Hash swap-in in device.
RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyHostMemToDevice(
embedding_device_cache_->hash_swap_value_addr_, lookup_result.data(),
@ -934,7 +928,7 @@ bool PsCacheManager::HashSwapDeviceIn(int *swap_in_ids, int *swap_in_index, cons
return true;
}
bool PsCacheManager::UpdataEmbeddingTable(const ::ps::SArray<float> &swap_out_data, int *swap_out_ids, size_t key) {
bool PsCacheManager::UpdataEmbeddingTable(const std::vector<float> &swap_out_data, int *swap_out_ids, size_t key) {
MS_ERROR_IF_NULL(embedding_device_cache_);
MS_ERROR_IF_NULL(embedding_device_cache_->cache_);
MS_ERROR_IF_NULL(swap_out_ids);
@ -942,7 +936,7 @@ bool PsCacheManager::UpdataEmbeddingTable(const ::ps::SArray<float> &swap_out_da
if (swap_out_ids_size == 0) {
return true;
}
::ps::SArray<int> lookup_ids(swap_out_ids_size, 0);
std::vector<int> lookup_ids(swap_out_ids_size, 0);
auto copy_len = swap_out_ids_size * sizeof(int);
auto ret = memcpy_s(lookup_ids.data(), copy_len, swap_out_ids, copy_len);
if (ret != EOK) {
@ -994,8 +988,8 @@ bool PsCacheManager::SyncHostEmbeddingTable() {
continue;
}
auto key = worker.GetParamKey(item.first);
::ps::SArray<int> lookup_ids(swap_indices_lens, 0);
::ps::SArray<float> swap_out_data;
std::vector<int> lookup_ids(swap_indices_lens, 0);
std::vector<float> swap_out_data;
auto embedding_size = hash_info.embedding_size;
swap_out_data.resize(swap_indices_lens * embedding_size);
auto host_hash_table_addr = hash_info.host_address.get();
@ -1038,8 +1032,8 @@ bool PsCacheManager::SyncDeviceEmbeddingTable() {
continue;
}
auto key = worker.GetParamKey(item.first);
::ps::SArray<int> lookup_ids(swap_indices_lens, 0);
::ps::SArray<float> swap_out_data;
std::vector<int> lookup_ids(swap_indices_lens, 0);
std::vector<float> swap_out_data;
auto embedding_size = hash_info.embedding_size;
swap_out_data.resize(swap_indices_lens * embedding_size);
std::unique_ptr<float[]> device_hash_table_addr_tmp =

View File

@ -29,9 +29,9 @@
#include "backend/kernel_compiler/kernel.h"
#include "utils/shape_utils.h"
#include "ir/tensor.h"
#include "ps/ps.h"
#include "ps/common.h"
#include "ps/constants.h"
#include "ps/worker.h"
#include "ps/ps_context.h"
#include "ps/ps_cache/ps_data/ps_data_prefetch.h"
#include "ps/ps_cache/embedding_hash_map.h"
#include "ps/ps_cache/ps_cache_factory.h"
@ -155,7 +155,7 @@ class PsCacheManager {
bool ParseDeviceData(size_t id, bool *need_swap_device_to_host, bool *need_swap_host_to_device, int *hash_index);
bool ParseHostDataHostToDevice(size_t id);
bool ParseHostDataDeviceToHost();
bool HashSwapDeviceOut(int *swap_out_index, ::ps::SArray<float> *swap_out_data, const HashTableInfo &hash_info);
bool HashSwapDeviceOut(int *swap_out_index, std::vector<float> *swap_out_data, const HashTableInfo &hash_info);
bool HashSwapDeviceIn(int *swap_in_ids, int *swap_in_index, const HashTableInfo &hash_info, size_t key);
bool HashSwapHostToDevice(const HashTableInfo &hash_info);
bool HashSwapDeviceToHost(const HashTableInfo &hash_info);
@ -165,7 +165,7 @@ class PsCacheManager {
float *hash_table_addr);
bool LookUpHostHashTable(size_t embedding_size, size_t indices_lens, const float *hash_table_addr,
const int *indices_addr, float *output_addr);
bool UpdataEmbeddingTable(const ::ps::SArray<float> &swap_out_data, int *swap_out_ids, size_t key);
bool UpdataEmbeddingTable(const std::vector<float> &swap_out_data, int *swap_out_ids, size_t key);
void LookUpTableTask(size_t indices_lens, size_t outer_dim_size, size_t first_dim_size, const float *input_addr,
const int *indices_addr, float *output_addr);
bool CheckFinishInsertInitInfo() const;

View File

@ -48,10 +48,10 @@ void PSContext::SetPSEnable(bool enabled) {
MS_LOG(WARNING) << "MS_ROLE is " << ms_role << ", which is invalid.";
}
worker_num_ = std::strtol(common::GetEnv("MS_WORKER_NUM").c_str(), nullptr, 10);
server_num_ = std::strtol(common::GetEnv("MS_SERVER_NUM").c_str(), nullptr, 10);
scheduler_host_ = common::GetEnv("MS_SCHED_HOST");
scheduler_port_ = std::strtol(common::GetEnv("MS_SCHED_PORT").c_str(), nullptr, 10);
worker_num_ = std::strtol(common::GetEnv(kEnvWorkerNum).c_str(), nullptr, 10);
server_num_ = std::strtol(common::GetEnv(kEnvPServerNum).c_str(), nullptr, 10);
scheduler_host_ = common::GetEnv(kEnvSchedulerHost);
scheduler_port_ = std::strtol(common::GetEnv(kEnvSchedulerPort).c_str(), nullptr, 10);
} else {
MS_LOG(INFO) << "PS mode is disabled.";
is_worker_ = false;

View File

@ -19,6 +19,7 @@
#include <string>
#include <memory>
#include "ps/constants.h"
namespace mindspore {
namespace ps {

View File

@ -15,13 +15,16 @@
*/
#include "ps/scheduler.h"
#include "ps/ps.h"
namespace mindspore {
namespace ps {
void Scheduler::Run() {
::ps::Start(0);
::ps::Finalize(0, true);
core::ClusterMetadata::instance()->Init(
PSContext::instance()->initial_worker_num(), PSContext::instance()->initial_server_num(),
PSContext::instance()->scheduler_host(), PSContext::instance()->scheduler_port());
scheduler_node_.Start();
scheduler_node_.Finish();
scheduler_node_.Stop();
exit(1);
}
} // namespace ps

View File

@ -16,6 +16,11 @@
#ifndef MINDSPORE_CCSRC_PS_SCHEDULER_H_
#define MINDSPORE_CCSRC_PS_SCHEDULER_H_
#include "ps/core/scheduler_node.h"
#include "ps/util.h"
#include "ps/ps_context.h"
namespace mindspore {
namespace ps {
class Scheduler {
@ -32,6 +37,7 @@ class Scheduler {
~Scheduler() = default;
Scheduler(const Scheduler &) = delete;
Scheduler &operator=(const Scheduler &) = delete;
core::SchedulerNode scheduler_node_;
};
} // namespace ps
} // namespace mindspore

View File

@ -17,7 +17,7 @@
#include "ps/util.h"
#include <unordered_map>
#include <vector>
#include "ps/common.h"
#include "ps/constants.h"
#include "ps/ps_context.h"
#include "utils/ms_utils.h"
@ -46,50 +46,10 @@ std::unordered_map<int64_t, std::string> Util::id_to_optimizer_nodes{
{3, kSparseFtrlOp},
};
bool Util::IsParamServerMode() { return PSContext::instance()->is_ps_mode(); }
bool Util::IsRoleOfWorker() { return PSContext::instance()->is_worker(); }
bool Util::IsRoleOfPServer() { return PSContext::instance()->is_server(); }
bool Util::IsRoleOfScheduler() { return PSContext::instance()->is_scheduler(); }
void Util::SetInternalEnvVar() {
if (IsParamServerMode()) {
auto comm_type = common::GetEnv(kEnvCommType);
if (!comm_type.empty()) {
(void)common::SetEnv(kDmlcCommType, comm_type.c_str());
}
auto interface = common::GetEnv(kEnvInterface);
if (!interface.empty()) {
(void)common::SetEnv(kDmlcInterface, interface.c_str());
}
auto server_num = common::GetEnv(kEnvPServerNum);
if (!server_num.empty()) {
(void)common::SetEnv(kDmlcPServerNum, server_num.c_str());
}
auto worker_num = common::GetEnv(kEnvWorkerNum);
if (!worker_num.empty()) {
(void)common::SetEnv(kDmlcWorkerNum, worker_num.c_str());
}
if (IsRoleOfScheduler()) {
(void)common::SetEnv(kDmlcRole, kRoleOfScheduler);
} else if (IsRoleOfPServer()) {
(void)common::SetEnv(kDmlcRole, kRoleOfPServer);
} else if (IsRoleOfWorker()) {
(void)common::SetEnv(kDmlcRole, kRoleOfWorker);
}
auto scheduler_host = common::GetEnv(kEnvSchedulerHost);
if (!scheduler_host.empty()) {
(void)common::SetEnv(kDmlcSchedulerHost, scheduler_host.c_str());
}
auto scheduler_port = common::GetEnv(kEnvSchedulerPort);
if (!scheduler_port.empty()) {
(void)common::SetEnv(kDmlcSchedulerPort, scheduler_port.c_str());
}
}
}
int64_t Util::optimizer_id(std::string name) {
if (optimizer_to_ids.count(name) > 0) {
return optimizer_to_ids[name];

View File

@ -37,11 +37,8 @@ struct ParamInitInfo {
class Util {
public:
static bool IsParamServerMode();
static bool IsRoleOfWorker();
static bool IsRoleOfPServer();
static bool IsRoleOfScheduler();
static void SetInternalEnvVar();
static int64_t optimizer_id(std::string name);
static std::string optimizer_name(int64_t id);
static std::string optimizer_node_name(int64_t id);

View File

@ -14,11 +14,10 @@
* limitations under the License.
*/
#include "ps/internal/worker.h"
#include "ps/worker.h"
namespace mindspore {
namespace ps {
namespace internal {
void Worker::Run() {
std::lock_guard<std::mutex> lock(running_mutex_);
core::ClusterMetadata::instance()->Init(
@ -198,7 +197,8 @@ void Worker::AddEmbeddingTable(const Key &key, const size_t &row_count) {
}
void Worker::InitPSEmbeddingTable(const size_t &key, const std::vector<size_t> &input_shape,
const std::vector<size_t> &indices_shape, const std::vector<size_t> &output_shape) {
const std::vector<size_t> &indices_shape, const std::vector<size_t> &output_shape,
const ParamInitInfoMessage &info) {
bool has_init = IsKeyInit(key);
if (has_init) {
MS_LOG(DEBUG) << "The key embedding table of key " << key << " is initialized.";
@ -210,6 +210,7 @@ void Worker::InitPSEmbeddingTable(const size_t &key, const std::vector<size_t> &
*embedding_table_meta.mutable_input_shape() = {input_shape.begin(), input_shape.end()};
*embedding_table_meta.mutable_indices_shape() = {indices_shape.begin(), indices_shape.end()};
*embedding_table_meta.mutable_output_shape() = {output_shape.begin(), output_shape.end()};
*embedding_table_meta.mutable_info() = info;
std::string kv_data = embedding_table_meta.SerializeAsString();
@ -295,19 +296,18 @@ void Worker::DoPSEmbeddingLookup(const Key &key, const std::vector<int> &lookup_
int64_t single_id_len = SizeToLong(lookup_result->size() / lookup_ids.size());
std::unordered_map<Key, std::shared_ptr<std::pair<float *, int64_t>>> id_addr_map;
std::shared_ptr<std::vector<float>> values = std::make_shared<std::vector<float>>();
int64_t value_offset = 0;
for (size_t i = 0; i < resp.size(); ++i) {
KVMessage message;
message.ParseFromArray(resp.at(i)->data(), resp.at(i)->size());
int64_t offset = 0;
values->clear();
for (auto j = 0; j < message.values_size(); j++) {
values->push_back(message.values(j));
}
MS_LOG(DEBUG) << "the embedding resp:" << values;
MS_LOG(DEBUG) << "The embedding resp:" << values;
for (auto k = 0; k < message.keys_size(); k++) {
const Key &key = message.keys(k);
float *addr = values->data() + offset;
offset += single_id_len;
float *addr = values->data() + value_offset;
value_offset += single_id_len;
id_addr_map[key] = std::make_shared<std::pair<float *, int64_t>>(std::make_pair(addr, single_id_len));
}
}
@ -969,6 +969,5 @@ void Worker::SendForPull(int cmd, const KVMessage &send, const KVPartitioner &pa
}
}
}
} // namespace internal
} // namespace ps
} // namespace mindspore

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -25,24 +25,38 @@
#include <functional>
#include <algorithm>
#include <map>
#include "ps/ps.h"
#include <mutex>
#include <unordered_set>
#include <unordered_map>
#include "utils/log_adapter.h"
#include "ir/tensor.h"
#include "ps/util.h"
#include "ps/common.h"
#include "ps/worker_proxy.h"
#include "ps/constants.h"
#include "utils/shape_utils.h"
#include "ps/ps_cache/ps_data/ps_data_prefetch.h"
#include "ps/core/worker_node.h"
#include "ps/embedding_table_shard_metadata.h"
#include "proto/comm.pb.h"
#include "proto/ps.pb.h"
#include "ps/ps_context.h"
namespace mindspore {
namespace ps {
template <typename T>
class Worker {
public:
static Worker &GetInstance() {
static Worker instance;
return instance;
}
using Callback = std::function<void()>;
using PartitionEmbeddingMessages = std::vector<std::pair<bool, EmbeddingTableLookup>>;
using PartitionKVMessages = std::vector<std::pair<bool, KVMessage>>;
using EmbeddingPartitioner = std::function<void(
const EmbeddingTableLookup &send, PartitionEmbeddingMessages *partition, const std::map<int64_t, int64_t> &attrs)>;
using KVPartitioner =
std::function<void(const KVMessage &send, PartitionKVMessages *partition, const std::map<int64_t, int64_t> &attrs)>;
void Run();
void Push(const std::vector<size_t> &keys, std::vector<uintptr_t> addrs, const ShapeVector &sizes);
@ -53,340 +67,89 @@ class Worker {
bool GetParamInitInServer(const std::string &param_name);
void SetKeyOptimId(size_t key, const std::string &optimizer_name);
void SetOptimInputShapes(size_t key, const ShapeVector &shape);
void AddEmbeddingTable(const ::ps::Key &key, const size_t &row_count);
void InitPSEmbeddingTable(const std::vector<size_t> &keys, std::vector<T> shapes, const ShapeVector &sizes);
void AddEmbeddingTable(const Key &key, const size_t &row_count);
void InitPSEmbeddingTable(const size_t &key, const std::vector<size_t> &input_shape,
const std::vector<size_t> &indices_shape, const std::vector<size_t> &output_shape,
const ParamInitInfoMessage &info);
void InitPSParamAndOptim(const AnfNodePtr &input_node, const tensor::TensorPtr &tensor);
void DoPSEmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids,
const ::ps::SArray<int> &lens, ::ps::SArray<T> *lookup_result, int64_t cmd);
void UpdateEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids,
const ::ps::SArray<T> &vals);
void DoPSEmbeddingLookup(const Key &key, const std::vector<int> &lookup_ids, std::vector<float> *lookup_result,
int64_t cmd);
// Pushes updated embedding rows (keys + row ids + values) to the servers.
void UpdateEmbeddingTable(const std::vector<Key> &keys, const std::vector<int> &lookup_ids,
const std::vector<float> &vals);
// True once Run() has completed successfully.
bool running() { return running_; }
void Finalize();
private:
// NOTE(review): two default constructors appear below — this looks like diff
// residue (old vs. new version of the file); only one can compile. Confirm
// which matches the current member list.
Worker() : kv_worker_(nullptr), running_(false), key_cnt_(0) {}
Worker() : running_(false), key_cnt_(0) {}
~Worker() = default;
Worker(const Worker &) = delete;
Worker &operator=(const Worker &) = delete;
void Initialize();
// Whether weights for this key have already been initialized (see init_keys_).
bool IsKeyInit(const size_t key);
void AddKeyToServerId(const Key &key);
void AddKeyByHashMod(const Key &key);
void InitPSOptimId(const size_t param_key);
void InitPSOptimInputShapes(const size_t key);
void InitPSParamData(const std::vector<size_t> &keys, void *origin_addr, size_t size);
// Intentionally empty slicer body.
static void EmbeddingLookupIdSlicer(const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &ranges,
std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced) {}
bool IsReadyForPush(const Key &key);
bool IsReadyForPull(const Key &key);
void PrepareSparseGradient(const size_t begin, const size_t end, const std::unordered_set<int> &distinct_ids,
const std::vector<std::pair<int, float *>> &indice_to_grads, const int *all_indice,
const size_t segment_size, float *gradient, int *indices);
void BuildSparseValue(const std::vector<int> &lengths, const size_t grad_index, const size_t indice_index,
const float *original_data, const float *grads, int *indices, std::vector<float> *reduced_data);
std::shared_ptr<WorkerProxy<T>> kv_worker_;
void PushData(const std::vector<Key> &keys, const std::vector<float> &vals, const std::vector<int> &lens = {},
int command = 0, int64_t priority = 0);
void PushSparseData(const std::vector<Key> &keys, const std::vector<float> &vals, const std::vector<int> &lens,
size_t grad_index, size_t indice_index, size_t first_dim_size, size_t outer_dim_size);
void PullData(const std::vector<Key> &keys, std::vector<float> *vals, std::vector<int> *lens = nullptr, int cmd = 0,
int64_t priority = 0);
// Partitioners: split one logical KV message into per-server messages.
void LookupIdPartitioner(const EmbeddingTableLookup &send, PartitionEmbeddingMessages *partition,
const std::map<int64_t, int64_t> &attrs);
void SparsePartitioner(const KVMessage &send, PartitionKVMessages *partition,
const std::map<int64_t, int64_t> &attrs);
void RoundRobinPartitioner(const KVMessage &send, PartitionKVMessages *partition,
const std::map<int64_t, int64_t> &attrs);
void WorkerInitEmbeddingPartitioner(const KVMessage &send, std::vector<std::pair<bool, KVMessage>> *partition,
const std::map<int64_t, int64_t> &attrs);
void UpdateEmbeddingPartitioner(const KVMessage &send, PartitionKVMessages *partition,
const std::map<int64_t, int64_t> &attrs);
void BroadcastPartitioner(const KVMessage &send, PartitionKVMessages *partition,
const std::map<int64_t, int64_t> &attrs);
void SendForPush(int cmd, const KVMessage &send, const KVPartitioner &partitioner,
const std::map<int64_t, int64_t> &attrs);
void SendForPull(int cmd, const KVMessage &send, const KVPartitioner &partitioner,
const std::map<int64_t, int64_t> &attrs, std::vector<float> *vals, std::vector<int> *lens);
int64_t server_num_;
bool running_;
std::mutex running_mutex_;
// Monotonically increasing counter used by SetParamKey to mint new keys.
size_t key_cnt_;
std::map<std::string, size_t> param_to_key_;
std::map<size_t, bool> init_keys_;
std::map<size_t, int64_t> key_to_optimId_;
std::map<size_t, std::vector<ShapeVector>> key_to_optim_shapes_;
std::map<std::string, bool> param_to_init_in_server_;
core::WorkerNode worker_node_;
EmbeddingPartitioner lookup_partitioner_;
KVPartitioner sparse_partitioner_;
KVPartitioner round_robin_partitioner_;
KVPartitioner worker_init_embedding_partitioner_;
KVPartitioner update_embedding_partitioner_;
KVPartitioner broadcast_partitioner_;
std::unordered_map<Key, int64_t> key_to_server_id_;
std::unordered_map<Key, size_t> embedding_row_cnt_;
std::unordered_map<Key, std::shared_ptr<std::vector<EmbeddingTableShardMetadata>>> embedding_table_ranges_;
};
template <typename T>
void Worker<T>::Run() {
// Connects this process to the scheduler/servers and creates the KV proxy.
// Idempotent: a second call logs and returns.
if (running_) {
MS_LOG(INFO) << "'Worker is already running.";
return;
}
MS_LOG(INFO) << "Worker starts connecting to scheduler and server...";
::ps::Start(0);
MS_LOG(INFO) << "Worker connected successfully.";
if (!::ps::IsWorker()) {
MS_LOG(EXCEPTION) << "The role is not worker.";
}
// Arguments: app_id 0, customer 0, lookup customer 1, general customer 2.
kv_worker_ = std::make_shared<WorkerProxy<T>>(0, 0, 1, 2);
running_ = true;
}
template <typename T>
void Worker<T>::Push(const std::vector<size_t> &keys, std::vector<uintptr_t> addrs, const ShapeVector &sizes) {
// Packs the gradient buffers at `addrs` (element counts in `sizes`) into one
// contiguous buffer and pushes it to the servers, dense or sparse depending
// on the optimizer bound to keys[0].
if (keys.size() == 0) {
MS_LOG(EXCEPTION) << "key size should be greater than zero";
}
if (key_to_optimId_.count(keys[0]) == 0) {
MS_LOG(EXCEPTION) << "no optim id found for key" << keys[0];
}
Key key = keys[0];
int64_t optim_id = key_to_optimId_[key];
bool is_sparse = false;
// Optimizer ids 1/2 (sparse adam variants) and 3 (sparse ftrl) are sparse.
if (optim_id == 1 || optim_id == 2 || optim_id == 3) {
is_sparse = true;
}
int64_t grad_index = -1;
int64_t indice_index = -1;
// Sparse adam gradient
if (optim_id == 1 || optim_id == 2) {
grad_index = 6;
indice_index = 7;
// Sparse ftrl gradient
} else if (optim_id == 3) {
grad_index = 0;
indice_index = 1;
}
// NOTE(review): the init value 0 is int, so accumulate sums in int even with
// std::plus<int64_t> — could overflow for very large totals; confirm.
size_t total_size = std::accumulate(sizes.begin(), sizes.end(), 0, std::plus<int64_t>());
::ps::SArray<T> total_buffer(total_size, 0);
size_t offset = 0;
size_t dst_size = 0;
size_t src_size = 0;
// Copy each input buffer into the packed buffer back-to-back.
for (size_t i = 0; i < sizes.size(); i++) {
void *dst_data = total_buffer.data() + offset / sizeof(T);
void *src_data = reinterpret_cast<void *>(addrs[i]);
MS_EXCEPTION_IF_NULL(dst_data);
MS_EXCEPTION_IF_NULL(src_data);
dst_size = sizes[i] * sizeof(T);
src_size = sizes[i] * sizeof(T);
auto ret = memcpy_s(dst_data, dst_size, src_data, src_size);
if (ret != 0) {
MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
return;
}
offset += sizes[i] * sizeof(T);
}
// Busy-wait until the server side is ready to accept a push for this key.
while (!kv_worker_->IsReadyForPush(keys[0])) {
continue;
}
std::vector<int> sizes_int;
(void)std::transform(sizes.begin(), sizes.end(), std::back_inserter(sizes_int),
[](const int64_t &value) { return static_cast<int>(value); });
if (!is_sparse) {
kv_worker_->PushData(::ps::SArray<::ps::Key>(keys), total_buffer, ::ps::SArray<int>(sizes_int));
} else {
// Sparse push also needs the variable's first dim and flattened outer dims.
std::vector<int64_t> &var_shape = key_to_optim_shapes_[key][0];
int64_t first_dim_size = var_shape[0];
int64_t outer_dim_size = std::accumulate(var_shape.begin() + 1, var_shape.end(), 1, std::multiplies<int64_t>());
kv_worker_->PushSparseData(::ps::SArray<::ps::Key>(keys), total_buffer, ::ps::SArray<int>(sizes_int), grad_index,
indice_index, first_dim_size, outer_dim_size);
}
}
template <typename T>
void Worker<T>::Pull(const size_t key, void *dev_addr, const size_t size) {
// Blocks until the server reports the key ready for pulling, then copies
// `size` bytes of pulled values into the caller-provided address.
MS_EXCEPTION_IF_NULL(dev_addr);
::ps::SArray<T> pulled(size / sizeof(T), 0);
// Busy-wait until the parameter server can serve this key.
while (!kv_worker_->IsReadyForPull(key)) {
continue;
}
kv_worker_->PullData({key}, &pulled);
auto ret = memcpy_s(dev_addr, size, pulled.data(), size);
if (ret != 0) {
MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
return;
}
}
template <typename T>
void Worker<T>::DoPSEmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids,
const ::ps::SArray<int> &lens, ::ps::SArray<T> *lookup_result, int64_t cmd) {
// Thin forwarder: delegates the embedding lookup to the KV proxy.
MS_EXCEPTION_IF_NULL(lookup_result);
kv_worker_->EmbeddingLookup(keys, lookup_ids, lens, lookup_result, cmd);
}
template <typename T>
void Worker<T>::UpdateEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids,
const ::ps::SArray<T> &vals) {
// Thin forwarder: delegates the embedding update to the KV proxy.
kv_worker_->UpdateEmbeddingTable(keys, lookup_ids, vals);
}
template <typename T>
void Worker<T>::Finalize() {
// Tears down the KV proxy exactly once; a no-op when not running.
if (!running_) {
return;
}
MS_LOG(INFO) << "Worker starts finalizing...";
kv_worker_->Finalize();
kv_worker_.reset();
running_ = false;
MS_LOG(INFO) << "Worker finalized successfully.";
}
template <typename T>
void Worker<T>::InitPSParamData(const std::vector<size_t> &keys, void *origin_addr, size_t size) {
// Pushes `size` bytes of initial weight data for `keys` to the servers and
// marks the first key as initialized.
MS_EXCEPTION_IF_NULL(origin_addr);
// Wraps the caller's buffer without copying; `size` is in bytes.
::ps::SArray<T> addr(reinterpret_cast<T *>(origin_addr), size / sizeof(T));
::ps::SArray<::ps::Key> key(keys);
::ps::SArray<int> lens;
lens.push_back(addr.size());
kv_worker_->PushData(key, addr, lens, kInitWeightsCmd);
init_keys_[key[0]] = true;
}
template <typename T>
void Worker<T>::SetOptimInputShapes(size_t key, const ShapeVector &shape) {
// Appends one optimizer-input shape for the key. operator[] default-
// constructs an empty vector on first use, so insert and append collapse
// into a single path (same effect as the explicit find/insert dance).
key_to_optim_shapes_[key].push_back(shape);
}
template <typename T>
void Worker<T>::InitPSOptimInputShapes(const size_t key) {
// Flattens all optimizer input shapes recorded for `key` into three parallel
// arrays (key per shape, per-shape rank, concatenated dims) and pushes them.
::ps::SArray<::ps::Key> keys;
::ps::SArray<int> shape_len;
::ps::SArray<T> all_shape;
std::vector<ShapeVector> shapes = key_to_optim_shapes_[key];
for (auto shape : shapes) {
keys.push_back(key);
// A scalar (rank-0) shape is encoded as a single dim of 1.
if (shape.size() == 0) {
shape_len.push_back(1);
all_shape.push_back(1);
} else {
shape_len.push_back(SizeToLong(shape.size()));
for (auto dim : shape) {
all_shape.push_back(static_cast<T>(dim));
}
}
}
MS_LOG(INFO) << "keys:" << keys;
MS_LOG(INFO) << "shape_len:" << shape_len;
MS_LOG(INFO) << "all_shape:" << all_shape;
if (!init_keys_[key]) {
init_keys_[key] = true;
}
kv_worker_->PushData(keys, all_shape, shape_len, kInitOptimInputsShapeCmd);
}
template <typename T>
bool Worker<T>::IsKeyInit(const size_t key) {
// A key counts as initialized only if present AND flagged true.
// Single lookup instead of find-then-operator[].
const auto iter = init_keys_.find(key);
return iter != init_keys_.end() && iter->second;
}
template <typename T>
size_t Worker<T>::SetParamKey(const std::string &param_name) {
// Returns the key already assigned to the parameter, or mints a fresh one
// from key_cnt_ on first sight.
size_t key = UINT64_MAX;
const auto iter = param_to_key_.find(param_name);
if (iter != param_to_key_.end()) {
key = iter->second;
MS_LOG(INFO) << param_name << " key is already set: key value is " << key;
} else {
key = key_cnt_++;
param_to_key_[param_name] = key;
MS_LOG(INFO) << "Set key " << key << " for parameter " << param_name;
}
return key;
}
template <typename T>
void Worker<T>::SetParamInitInServer(const std::string &param_name, bool init_in_server) {
// Records whether this parameter's weights are initialized server-side.
MS_LOG(INFO) << "Set parameter " << param_name << " init_in_server:" << init_in_server;
param_to_init_in_server_[param_name] = init_in_server;
}
template <typename T>
bool Worker<T>::GetParamInitInServer(const std::string &param_name) {
// Unknown parameters default to false (client-side init).
const auto iter = param_to_init_in_server_.find(param_name);
return iter == param_to_init_in_server_.end() ? false : iter->second;
}
template <typename T>
size_t Worker<T>::GetParamKey(const std::string &param_name) {
// Looks up the key assigned to the parameter; kInvalidKey when unknown.
size_t key = kInvalidKey;
const auto iter = param_to_key_.find(param_name);
if (iter != param_to_key_.end()) {
key = iter->second;
MS_LOG(DEBUG) << "Get key of parameter " << param_name << " key is " << key;
}
return key;
}
template <typename T>
void Worker<T>::SetKeyOptimId(size_t key, const std::string &optimizer_name) {
// Maps the parameter key to the numeric id of its optimizer.
key_to_optimId_[key] = Util::optimizer_id(optimizer_name);
}
template <typename T>
void Worker<T>::InitPSOptimId(const size_t param_key) {
// Sends the optimizer id bound to `param_key` to the servers.
if (key_to_optimId_.count(param_key) == 0) {
MS_LOG(EXCEPTION) << "Can't find optimizer id of parameter key " << param_key;
}
int64_t optim_id = key_to_optimId_[param_key];
::ps::SArray<::ps::Key> keys = {param_key};
::ps::SArray<T> optim_id_vals = {static_cast<T>(optim_id)};
::ps::SArray<int> optim_id_lens = {optim_id_vals.size()};
kv_worker_->PushData(keys, optim_id_vals, optim_id_lens, kInitWeightToOptimIdCmd);
}
template <typename T>
void Worker<T>::InitPSEmbeddingTable(const std::vector<size_t> &keys, std::vector<T> shapes, const ShapeVector &sizes) {
// Initializes the server-side embedding table for keys[0]; no-op if the key
// was already initialized. Blocks until the servers acknowledge.
bool has_init = IsKeyInit(keys[0]);
if (has_init) {
MS_LOG(DEBUG) << "The key embedding table of key " << keys[0] << " is initialized.";
return;
}
::ps::SArray<T> shapes_val;
for (auto dim : shapes) {
shapes_val.push_back(dim);
}
std::vector<int> sizes_int;
(void)std::transform(sizes.begin(), sizes.end(), std::back_inserter(sizes_int),
[](const int64_t &value) { return static_cast<int>(value); });
// Wait() makes the init synchronous.
kv_worker_->Wait(
kv_worker_->InitEmbeddingTable(::ps::SArray<::ps::Key>(keys), shapes_val, ::ps::SArray<int>(sizes_int)));
}
template <typename T>
void Worker<T>::InitPSParamAndOptim(const AnfNodePtr &input_node, const tensor::TensorPtr &tensor) {
// Initializes a parameter (and its optimizer metadata) on the parameter
// server, once per key. Weights are pushed from the worker only when the
// parameter is not flagged for server-side init and cache is disabled.
MS_EXCEPTION_IF_NULL(tensor);
MS_EXCEPTION_IF_NULL(input_node);
auto pk_node = input_node->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(pk_node);
const std::string &param_name = pk_node->fullname_with_scope();
void *param_data = tensor->data_c();
size_t param_size = LongToSize(tensor->data().nbytes());
size_t param_key = GetParamKey(param_name);
if (param_key == kInvalidKey) {
MS_LOG(DEBUG) << "Parameter " << param_name << " has no key assigned.";
return;
}
bool init_in_server = false;
auto param_info_ptr = pk_node->param_info();
if (param_info_ptr != nullptr && param_info_ptr->init_in_server()) {
init_in_server = true;
}
SetParamInitInServer(param_name, init_in_server);
bool init = IsKeyInit(param_key);
if (!init) {
MS_LOG(INFO) << "Init paramter and optimizer in parameter server side for " << param_name
<< ", whether init in server: " << init_in_server;
kv_worker_->AddKeyToServerId(param_key);
// With the data cache enabled, init is deferred (handled elsewhere).
if (!PsDataPrefetch::GetInstance().cache_enable()) {
if (!init_in_server) {
// Weight payload length is sent as int on the wire, hence the cap.
if (param_size > INT_MAX) {
MS_LOG(EXCEPTION) << "PS mode max weight size is " << INT_MAX << ", " << param_name << " size is "
<< param_size;
}
InitPSParamData({param_key}, param_data, param_size);
}
InitPSOptimId(param_key);
InitPSOptimInputShapes(param_key);
}
}
}
template <typename T>
void Worker<T>::AddEmbeddingTable(const ::ps::Key &key, const size_t &row_count) {
// Registers the table's row count with the proxy unless the key is already
// initialized (guard-clause form of the original check).
if (IsKeyInit(key)) {
return;
}
kv_worker_->AddEmbeddingTable(key, row_count);
}
// NOTE(review): two conflicting definitions of `worker` appear below — diff
// residue (templated singleton vs. its replacement); only one can compile.
// Confirm which is current.
static Worker<float> &worker = Worker<float>::GetInstance();
static Worker &worker = Worker::GetInstance();
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_WORKER_H_

View File

@ -1,873 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_WORKER_PROXY_H_
#define MINDSPORE_CCSRC_PS_WORKER_PROXY_H_
#include <map>
#include <numeric>
#include <functional>
#include <unordered_map>
#include <unordered_set>
#include <algorithm>
#include <utility>
#include <memory>
#include <vector>
#include "ps/ps.h"
#include "ps/util.h"
#include "backend/kernel_compiler/common_utils.h"
#include "ps/ps_context.h"
namespace mindspore {
namespace ps {
// Proxy over ::ps::KVWorker that routes embedding lookups, sparse pushes and
// general KV traffic to the right servers via pluggable slicers, using two
// dedicated customers (lookup vs. general) so responses can be gathered
// independently.
template <typename T>
class WorkerProxy : public ::ps::KVWorker<T> {
public:
using Worker = ::ps::KVWorker<T>;
using Callback = std::function<void()>;
using SlicedKVs = std::vector<std::pair<bool, ::ps::KVPairs<T>>>;
using Slicer = std::function<void(int64_t ts, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &ranges,
SlicedKVs *sliced, const std::map<int64_t, int64_t> &attrs)>;
using ::ps::SimpleApp::obj_;
explicit WorkerProxy(int64_t app_id, int64_t customer_id, int64_t lookup_customer_id, int64_t general_customer_id)
: Worker(app_id, customer_id) {
server_num_ = ::ps::NumServers();
MS_LOG(INFO) << "Server num:" << server_num_;
PSContext::instance()->SetPSRankId(::ps::MyRank());
using std::placeholders::_1;
using std::placeholders::_2;
using std::placeholders::_3;
using std::placeholders::_4;
using std::placeholders::_5;
// Separate customers: lookup responses and general responses are processed
// by different handlers.
lookup_customer_ = std::unique_ptr<::ps::Customer>(
new ::ps::Customer(app_id, lookup_customer_id, std::bind(&WorkerProxy<T>::ProcessLookupResult, this, _1)));
general_customer_ = std::unique_ptr<::ps::Customer>(
new ::ps::Customer(app_id, general_customer_id, std::bind(&WorkerProxy<T>::ProcessResponse, this, _1)));
// Bind each slicing strategy once; they are selected per request kind.
lookup_slicer_ = std::bind(&WorkerProxy<T>::LookupIdSlicer, this, _1, _2, _3, _4, _5);
sparse_slicer_ = std::bind(&WorkerProxy<T>::SparseSlicer, this, _1, _2, _3, _4, _5);
broadcast_slicer_ = std::bind(&WorkerProxy<T>::BroadcastSlicer, this, _1, _2, _3, _4, _5);
round_robin_slicer_ = std::bind(&WorkerProxy<T>::RoundRobinSlicer, this, _1, _2, _3, _4, _5);
worker_init_embedding_slicer_ = std::bind(&WorkerProxy<T>::WorkerInitEmbeddingSlicer, this, _1, _2, _3, _4, _5);
update_embedding_slicer_ = std::bind(&WorkerProxy<T>::UpdateEmbeddingSlicer, this, _1, _2, _3, _4, _5);
}
~WorkerProxy() override = default;
void AddEmbeddingTable(const ::ps::Key &key, const size_t &row_count);
void AddKeyToServerId(const ::ps::Key &key);
void EmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids,
const ::ps::SArray<int> &lens, ::ps::SArray<T> *outs, int64_t cmd = 0,
const Callback &cb = nullptr, int64_t priority = 0);
int64_t InitEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &vals,
const ::ps::SArray<int> &lens = {}, const Callback &cb = nullptr, int64_t priority = 0);
void UpdateEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids,
const ::ps::SArray<T> &vals, const Callback &cb = nullptr, int64_t priority = 0);
bool IsReadyForPush(const Key &key);
bool IsReadyForPull(const Key &key);
void PushData(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &vals, const ::ps::SArray<int> &lens = {},
int64_t cmd = 0, int64_t priority = 0);
void PushSparseData(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &vals, const ::ps::SArray<int> &lens,
size_t grad_index, size_t indice_index, size_t first_dim_size, size_t outer_dim_size);
void PullData(const ::ps::SArray<::ps::Key> &keys, ::ps::SArray<T> *vals, ::ps::SArray<int> *lens = nullptr,
int64_t cmd = 0, int64_t priority = 0);
void Finalize();
private:
// Registers the response callback for a lookup request; returns timestamp.
template <typename C>
int64_t AddLookupCB(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids, C *vals, int64_t cmd,
const Callback &cb);
// Registers the response callback for a general request; returns timestamp.
int64_t AddGeneralRspCB(const ::ps::SArray<::ps::Key> &keys, ::ps::SArray<T> *vals, ::ps::SArray<int> *lens,
int64_t cmd, const Callback &cb);
void LookupIdSlicer(int64_t timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced, const std::map<int64_t, int64_t> &attrs);
void SparseSlicer(int64_t timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced, const std::map<int64_t, int64_t> &attrs);
void BroadcastSlicer(int64_t timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced, const std::map<int64_t, int64_t> &attrs);
void RoundRobinSlicer(int64_t timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced,
const std::map<int64_t, int64_t> &attrs);
void WorkerInitEmbeddingSlicer(int64_t timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced,
const std::map<int64_t, int64_t> &attrs);
void UpdateEmbeddingSlicer(int timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced,
const std::map<int64_t, int64_t> &attrs);
void ProcessLookupResult(const ::ps::Message &msg);
void ProcessResponse(const ::ps::Message &msg);
void Send(::ps::Customer *customer, int64_t timestamp, bool push, bool pull, int64_t cmd, const ::ps::KVPairs<T> &kvs,
const Slicer &slicer, std::map<int64_t, int64_t> attrs = {});
void AddKeyByHashMod(const ::ps::Key &key);
void PrepareSparseGradient(const size_t begin, const size_t end, const std::unordered_set<int> &distinct_ids,
const std::vector<std::pair<int, T *>> &indice_to_grad, const int *all_indice,
const size_t segment_size, T *gradient, int *indice);
void BuildSparseValue(const ::ps::SArray<int> &lengths, const size_t grad_index, const size_t indice_index,
const T *original_data, const T *grads, int *indices, ::ps::SArray<T> *reduced_data);
int64_t server_num_;
std::unique_ptr<::ps::Customer> lookup_customer_;
std::unique_ptr<::ps::Customer> general_customer_;
// Per-key shard ranges for embedding tables (filled by AddEmbeddingTable).
std::unordered_map<::ps::Key, std::shared_ptr<std::vector<::ps::Range>>> embedding_table_ranges_;
// Responses gathered per request timestamp; guarded by mutex_.
std::unordered_map<int64_t, std::vector<::ps::KVPairs<T>>> lookup_results_;
std::unordered_map<int64_t, std::map<int64_t, ::ps::KVPairs<T>>> gathered_response_;
std::mutex mutex_;
Slicer lookup_slicer_;
Slicer sparse_slicer_;
Slicer broadcast_slicer_;
Slicer round_robin_slicer_;
Slicer worker_init_embedding_slicer_;
Slicer update_embedding_slicer_;
std::unordered_map<int64_t, Callback> lookup_callbacks_;
std::unordered_map<int64_t, Callback> general_callbacks_;
// Number of per-server replies expected for each in-flight timestamp.
std::unordered_map<int64_t, int64_t> expected_result_count_;
std::unordered_map<::ps::Key, int64_t> key_to_server_id_;
std::unordered_map<::ps::Key, size_t> embedding_row_cnt_;
};
template <typename T>
void WorkerProxy<T>::AddEmbeddingTable(const ::ps::Key &key, const size_t &row_count) {
// Splits `row_count` embedding rows into contiguous per-server [begin, end]
// ranges (inclusive) and records them for later lookup slicing.
uint64_t begin = 0;
uint64_t end = 0;
for (int64_t i = 0; i < server_num_; i++) {
int64_t local_row_cnt = Util::LocalShard(row_count, i, server_num_);
if (i == 0) {
end = local_row_cnt - 1;
} else {
begin = end + 1;
end += local_row_cnt;
}
::ps::Range range(begin, end);
if (embedding_table_ranges_.count(key) == 0) {
embedding_table_ranges_[key] = std::make_shared<std::vector<::ps::Range>>();
MS_EXCEPTION_IF_NULL(embedding_table_ranges_[key]);
}
embedding_table_ranges_[key]->push_back(range);
}
embedding_row_cnt_[key] = row_count;
}
template <typename T>
void WorkerProxy<T>::AddKeyByHashMod(const ::ps::Key &key) {
// Assigns keys without embedding ranges to a server by simple modulo.
if (server_num_ == 0) {
MS_LOG(EXCEPTION) << "Server number is invalid:0";
}
const auto server_id = static_cast<int64_t>(key % server_num_);
key_to_server_id_[key] = server_id;
MS_LOG(INFO) << "The server id of key " << key << " is " << server_id;
}
template <typename T>
void WorkerProxy<T>::AddKeyToServerId(const ::ps::Key &key) {
// Current policy: hash-mod placement.
AddKeyByHashMod(key);
}
template <typename T>
void WorkerProxy<T>::EmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids,
const ::ps::SArray<int> &lens, ::ps::SArray<T> *outs, int64_t cmd,
const Callback &cb, int64_t priority) {
// Synchronous embedding lookup: slices row ids per server, sends, and waits
// for all expected replies. Note: lookup ids travel in kvs.lens.
int64_t ts = AddLookupCB(keys, lookup_ids, outs, cmd, cb);
::ps::KVPairs<T> kvs;
kvs.keys = keys;
kvs.lens = lookup_ids;
kvs.priority = priority;
expected_result_count_[ts] = 0;
Send(lookup_customer_.get(), ts, true, true, cmd, kvs, lookup_slicer_);
// Pre-credit servers that were skipped by the slicer so WaitRequest can
// complete after the remaining replies arrive.
int64_t expect_rt_count = expected_result_count_[ts];
lookup_customer_->AddResponse(ts, server_num_ - expect_rt_count);
lookup_customer_->WaitRequest(ts);
expected_result_count_.erase(ts);
}
template <typename T>
int64_t WorkerProxy<T>::InitEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &vals,
const ::ps::SArray<int> &lens, const Callback &cb, int64_t priority) {
// Broadcasts embedding-table init data to every server; returns the request
// timestamp so the caller can Wait() on it.
// NOTE(review): `cb` is accepted but never used here — confirm intended.
int64_t ts = obj_->NewRequest(::ps::kServerGroup);
::ps::KVPairs<T> kvs;
kvs.keys = keys;
kvs.vals = vals;
kvs.lens = lens;
kvs.priority = priority;
Send(obj_, ts, true, false, kInitEmbeddingsCmd, kvs, broadcast_slicer_);
return ts;
}
template <typename T>
void WorkerProxy<T>::UpdateEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids,
const ::ps::SArray<T> &vals, const Callback &cb, int64_t priority) {
// Synchronously pushes updated embedding rows; row ids travel in kvs.lens.
// NOTE(review): `cb` is accepted but nullptr is registered instead — confirm.
int ts = AddGeneralRspCB(keys, nullptr, nullptr, 0, nullptr);
::ps::KVPairs<T> kvs;
kvs.keys = keys;
kvs.lens = lookup_ids;
kvs.vals = vals;
kvs.priority = priority;
expected_result_count_[ts] = 0;
Send(general_customer_.get(), ts, true, false, kUpdateEmbeddingsCmd, kvs, update_embedding_slicer_);
// Pre-credit servers the slicer skipped so WaitRequest can complete.
if (expected_result_count_[ts] < server_num_) {
general_customer_->AddResponse(ts, server_num_ - expected_result_count_[ts]);
}
general_customer_->WaitRequest(ts);
expected_result_count_.erase(ts);
}
template <typename T>
bool WorkerProxy<T>::IsReadyForPush(const Key &key) {
// Queries the server for push readiness; a positive flag value means ready.
::ps::SArray<T> ready_flag(1, 0);
PullData({key}, &ready_flag, nullptr, kCheckReadyForPushCmd);
return ready_flag[0] > 0;
}
template <typename T>
bool WorkerProxy<T>::IsReadyForPull(const Key &key) {
// Queries the server for pull readiness; a positive flag value means ready.
::ps::SArray<T> ready_flag(1, 0);
PullData({key}, &ready_flag, nullptr, kCheckReadyForPullCmd);
return ready_flag[0] > 0;
}
template <typename T>
void WorkerProxy<T>::PushData(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &vals,
const ::ps::SArray<int> &lens, int64_t cmd, int64_t priority) {
// Synchronous push. Slicer selection: embedding keys use the init-embedding
// or broadcast slicer; all other keys are routed round-robin.
int64_t ts = AddGeneralRspCB(keys, nullptr, nullptr, cmd, nullptr);
::ps::KVPairs<T> kvs;
kvs.keys = keys;
kvs.vals = vals;
kvs.lens = lens;
kvs.priority = priority;
if (embedding_table_ranges_.count(keys[0])) {
if (cmd == kInitWeightsCmd) {
Send(general_customer_.get(), ts, true, false, cmd, kvs, worker_init_embedding_slicer_);
} else {
Send(general_customer_.get(), ts, true, false, cmd, kvs, broadcast_slicer_);
}
} else {
Send(general_customer_.get(), ts, true, false, cmd, kvs, round_robin_slicer_);
}
// Pre-credit servers the slicer skipped so WaitRequest can complete.
if (expected_result_count_[ts] < server_num_) {
general_customer_->AddResponse(ts, server_num_ - expected_result_count_[ts]);
}
general_customer_->WaitRequest(ts);
}
template <typename T>
void WorkerProxy<T>::PushSparseData(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &vals,
const ::ps::SArray<int> &lens, size_t grad_index, size_t indice_index,
size_t first_dim_size, size_t outer_dim_size) {
// Synchronous sparse push: embedding keys go through the sparse slicer with
// layout attrs (0:grad idx, 1:indice idx, 2:first dim, 3:outer dims);
// other keys fall back to round-robin routing.
int64_t ts = AddGeneralRspCB(keys, nullptr, nullptr, 0, nullptr);
::ps::KVPairs<T> kvs;
kvs.keys = keys;
kvs.vals = vals;
kvs.lens = lens;
const int64_t cmd = 0;
if (embedding_table_ranges_.count(keys[0])) {
std::map<int64_t, int64_t> attrs{{0, grad_index}, {1, indice_index}, {2, first_dim_size}, {3, outer_dim_size}};
Send(general_customer_.get(), ts, true, false, cmd, kvs, sparse_slicer_, attrs);
} else {
Send(general_customer_.get(), ts, true, false, cmd, kvs, round_robin_slicer_);
}
// Pre-credit servers the slicer skipped so WaitRequest can complete.
if (expected_result_count_[ts] < server_num_) {
general_customer_->AddResponse(ts, server_num_ - expected_result_count_[ts]);
}
general_customer_->WaitRequest(ts);
}
template <typename T>
void WorkerProxy<T>::PullData(const ::ps::SArray<::ps::Key> &keys, ::ps::SArray<T> *vals, ::ps::SArray<int> *lens,
int64_t cmd, int64_t priority) {
// Synchronous pull; gathered values land in *vals via the general callback.
MS_EXCEPTION_IF_NULL(vals);
int64_t ts = AddGeneralRspCB(keys, vals, lens, cmd, nullptr);
::ps::KVPairs<T> kvs;
kvs.keys = keys;
kvs.priority = priority;
// Embedding keys are broadcast to all shards; others routed round-robin.
if (embedding_table_ranges_.count(keys[0])) {
Send(general_customer_.get(), ts, false, true, cmd, kvs, broadcast_slicer_);
} else {
Send(general_customer_.get(), ts, false, true, cmd, kvs, round_robin_slicer_);
}
// Pre-credit servers the slicer skipped so WaitRequest can complete.
if (expected_result_count_[ts] < server_num_) {
general_customer_->AddResponse(ts, server_num_ - expected_result_count_[ts]);
}
general_customer_->WaitRequest(ts);
}
template <typename T>
void WorkerProxy<T>::Finalize() {
// Broadcasts a finalize command to all servers (dummy 1-element payload),
// waits for acknowledgement, then shuts down the ps-lite runtime.
int64_t ts = obj_->NewRequest(::ps::kServerGroup);
::ps::KVPairs<T> kvs;
kvs.keys.push_back(0);
kvs.vals.push_back(0.0f);
Send(obj_, ts, true, false, kFinalizeCmd, kvs, broadcast_slicer_);
obj_->WaitRequest(ts);
::ps::Finalize(0, true);
}
template <typename T>
template <typename C>
int64_t WorkerProxy<T>::AddLookupCB(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids,
C *lookup_result, int64_t cmd, const Callback &cb) {
// Registers a callback that stitches per-server lookup replies back into
// `lookup_result` in lookup_ids order. Returns the request timestamp.
MS_EXCEPTION_IF_NULL(lookup_result);
int64_t ts = lookup_customer_->NewRequest(::ps::kServerGroup);
const auto &callback = [this, ts, keys, lookup_ids, lookup_result, cb]() mutable {
mutex_.lock();
auto &kvs = lookup_results_[ts];
mutex_.unlock();
if (lookup_ids.empty()) {
MS_LOG(EXCEPTION) << "Lookup id is empty.";
}
// Each looked-up id occupies the same number of elements in the result.
int64_t single_id_len = SizeToLong(lookup_result->size() / lookup_ids.size());
// Map each returned row id to the address of its values in the reply.
std::unordered_map<Key, std::shared_ptr<std::pair<T *, int64_t>>> id_addr_map;
for (const auto &s : kvs) {
int64_t offset = 0;
for (size_t i = 0; i < s.keys.size(); i++) {
const Key &key = s.keys[i];
T *addr = s.vals.data() + offset;
offset += single_id_len;
id_addr_map[key] = std::make_shared<std::pair<T *, int64_t>>(std::make_pair(addr, single_id_len));
MS_EXCEPTION_IF_NULL(id_addr_map[key]);
}
}
// Copy rows into the output buffer in the caller's lookup order; ids the
// servers did not return leave their slot untouched (zeros).
T *result_addr = lookup_result->data();
MS_EXCEPTION_IF_NULL(result_addr);
int64_t offset = 0;
size_t dst_size = 0;
size_t src_size = 0;
void *dst_data = nullptr;
void *src_data = nullptr;
for (size_t i = 0; i < lookup_ids.size(); i++) {
if (id_addr_map.count(lookup_ids[i]) == 0) {
offset += single_id_len;
continue;
}
auto &pair = id_addr_map[static_cast<Key>(lookup_ids[i])];
int64_t size = single_id_len * sizeof(T);
dst_size = size;
src_size = size;
dst_data = result_addr + offset;
src_data = pair->first;
MS_EXCEPTION_IF_NULL(dst_data);
MS_EXCEPTION_IF_NULL(src_data);
auto ret = memcpy_s(dst_data, dst_size, src_data, src_size);
if (ret != 0) {
MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
return;
}
offset += single_id_len;
}
mutex_.lock();
lookup_results_.erase(ts);
mutex_.unlock();
if (cb) cb();
};
lookup_callbacks_[ts] = callback;
return ts;
}
template <typename T>
int64_t WorkerProxy<T>::AddGeneralRspCB(const ::ps::SArray<::ps::Key> &keys, ::ps::SArray<T> *vals,
::ps::SArray<int> *lens, int64_t cmd, const Callback &cb) {
// Registers a callback that concatenates per-server responses (ordered by
// server id via the std::map) into *vals / *lens. Returns the timestamp.
// NOTE(review): callers like PushData register this with vals == nullptr;
// vals->clear() below would then deref null if the callback fires — confirm
// ProcessResponse never invokes it for push-only requests.
int64_t ts = general_customer_->NewRequest(::ps::kServerGroup);
const auto &callback = [this, ts, keys, vals, lens, cb]() mutable {
mutex_.lock();
std::map<int64_t, ::ps::KVPairs<T>> server_kvs = gathered_response_[ts];
mutex_.unlock();
vals->clear();
for (auto kvs : server_kvs) {
for (auto val : kvs.second.vals) {
vals->push_back(val);
}
if (lens) {
for (auto len : kvs.second.lens) {
lens->push_back(len);
}
}
}
mutex_.lock();
gathered_response_.erase(ts);
mutex_.unlock();
if (cb) {
cb();
}
};
general_callbacks_[ts] = callback;
return ts;
}
template <typename T>
void WorkerProxy<T>::LookupIdSlicer(int64_t timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced,
const std::map<int64_t, int64_t> &attrs) {
// Slices a lookup request per embedding shard: for each server's range,
// dedups the ids that fall inside it and marks the slice active only when
// at least one id matched. Note: lookup ids arrive in send.lens.
MS_EXCEPTION_IF_NULL(sliced);
int32_t *lookup_ids = send.lens.data();
size_t id_size = send.lens.size();
const Key &key = send.keys[0];
const std::vector<::ps::Range> &ranges = *(embedding_table_ranges_[key]);
sliced->resize(ranges.size());
for (size_t i = 0; i < ranges.size(); i++) {
const ::ps::Range &range = ranges[i];
const auto &begin = range.begin();
const auto &end = range.end();
std::unordered_set<int64_t> unique_ids;
auto &kvs = sliced->at(i).second;
// First key/val entry identifies the embedding table itself.
kvs.keys.push_back(key);
kvs.vals.push_back(0.0f);
for (size_t j = 0; j < id_size; j++) {
auto lookup_id = static_cast<uint64_t>(lookup_ids[j]);
// If lookup_id is out of range, like negative number, unique_ids will not contain it.
// Servers always get lookup_ids in its embedding table range.
if (lookup_id >= begin && lookup_id <= end) {
unique_ids.insert(lookup_id);
}
}
for (const auto &lookup_id : unique_ids) {
kvs.keys.push_back(lookup_id);
kvs.vals.push_back(0.0f);
}
// A slice with only the table key has no ids for this server: inactive.
if (kvs.keys.size() <= 1) {
sliced->at(i).first = false;
} else {
sliced->at(i).first = true;
expected_result_count_[timestamp] += 1;
}
}
}
template <typename T>
void WorkerProxy<T>::SparseSlicer(int64_t timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced,
const std::map<int64_t, int64_t> &attrs) {
// Slices a sparse gradient push per embedding shard: for each server's row
// range, filters and reduces the (indices, gradients) pairs that fall into
// it, then rebuilds a compact payload. attrs: 0=grad idx, 1=indice idx,
// 2=first dim, 3=outer dims (set by PushSparseData).
MS_EXCEPTION_IF_NULL(sliced);
// Init variables
T *data = send.vals.data();
if (attrs.count(0) == 0 || attrs.count(1) == 0 || attrs.count(2) == 0 || attrs.count(3) == 0) {
MS_LOG(EXCEPTION) << "Invalid attrs keys";
}
auto iter = attrs.find(0);
size_t grad_index = static_cast<size_t>(iter->second);
iter = attrs.find(1);
size_t indice_index = static_cast<size_t>(iter->second);
iter = attrs.find(2);
size_t first_dim_size = static_cast<size_t>(iter->second);
iter = attrs.find(3);
size_t outer_dim_size = static_cast<size_t>(iter->second);
int grad_size = send.lens[grad_index];
int indice_size = send.lens[indice_index];
// NOTE(review): divides by indice_size with no zero guard — confirm callers
// never send an empty indice segment.
int segment_size = grad_size / indice_size;
int64_t grad_offset = 0;
int64_t indice_offset = 0;
for (size_t i = 0; i < grad_index; i++) {
grad_offset += send.lens[i];
}
for (size_t j = 0; j < indice_index; j++) {
indice_offset += send.lens[j];
}
T *grad_data = data + grad_offset;
int *indice_data = reinterpret_cast<int *>(data) + indice_offset;
// Build the mappings of indice to gradient
std::vector<std::pair<int, T *>> indice_to_grads;
for (int i = 0; i < indice_size; i++) {
int indice = indice_data[i];
T *grad = grad_data + i * segment_size;
indice_to_grads.push_back(std::make_pair(indice, grad));
}
const Key &key = send.keys[0];
const std::vector<::ps::Range> &ranges = *(embedding_table_ranges_[key]);
sliced->resize(ranges.size());
// Construct reduced sparse data for each server
for (size_t i = 0; i < ranges.size(); i++) {
const ::ps::Range &range = ranges[i];
const auto &begin = range.begin();
const auto &end = range.end();
auto &kvs = sliced->at(i).second;
kvs.keys = send.keys;
kvs.lens = send.lens;
// Prepare the sparse gradient and indice
std::vector<int> indice_ids;
std::unordered_set<int> distinct_ids;
for (int j = 0; j < indice_size; j++) {
size_t indice = static_cast<size_t>(indice_data[j]);
if (indice >= begin && indice <= end) {
indice_ids.push_back(indice);
distinct_ids.insert(indice);
}
}
size_t indices_size = indice_ids.size();
if (indices_size > 0) {
int slice_segment_size = indices_size * segment_size;
std::vector<T> src_grad_data(slice_segment_size);
std::vector<int> src_indice_data(indices_size);
PrepareSparseGradient(begin, end, distinct_ids, indice_to_grads, indice_data, segment_size, src_grad_data.data(),
src_indice_data.data());
// Reduce the sparse gradient and indice
std::vector<T> new_grad(slice_segment_size);
std::vector<int> new_indices(indices_size);
mindspore::kernel::SparseGradient<int> unique_sparse_grad({new_grad.data(), new_indices.data(), indices_size});
Util::ReduceSparseGradient(src_grad_data.data(), src_indice_data.data(), indices_size, segment_size,
first_dim_size, outer_dim_size, &unique_sparse_grad);
// Update the length of reduce sparse gradient and indice
::ps::SArray<int> reduced_lens;
reduced_lens.CopyFrom(kvs.lens);
reduced_lens[grad_index] = unique_sparse_grad.indices_size_ * segment_size;
reduced_lens[indice_index] = unique_sparse_grad.indices_size_;
// Build the sparse value to be sent
size_t total_size = std::accumulate(reduced_lens.begin(), reduced_lens.end(), 0, std::plus<int>());
::ps::SArray<T> reduced_data(total_size, 0);
BuildSparseValue(reduced_lens, grad_index, indice_index, data, unique_sparse_grad.value_,
unique_sparse_grad.indices_, &reduced_data);
kvs.lens = reduced_lens;
kvs.vals = reduced_data;
}
// No indices for this shard: send a sentinel payload (-100) instead.
// NOTE(review): `no_keys` is filled but never assigned to kvs.keys —
// looks like dead code; confirm intended.
if (indices_size <= 0) {
::ps::SArray<T> no_keys;
::ps::SArray<T> no_vals;
::ps::SArray<T> no_lens;
no_keys.push_back(key);
no_vals.push_back(-100);
kvs.vals = no_vals;
kvs.lens = no_lens;
}
sliced->at(i).first = true;
expected_result_count_[timestamp] += 1;
}
}
template <typename T>
void WorkerProxy<T>::PrepareSparseGradient(const size_t begin, const size_t end,
                                           const std::unordered_set<int> &distinct_ids,
                                           const std::vector<std::pair<int, T *>> &indice_to_grads,
                                           const int *all_indice, const size_t segment_size, T *gradient,
                                           int *indices) {
  // Gather the gradient segments whose indices belong to this server's key
  // range (i.e. are present in distinct_ids) into the contiguous output
  // buffers 'gradient' and 'indices', preserving the order of indice_to_grads.
  // NOTE(review): 'all_indice' is only null-checked here, not read — presumably
  // kept for interface symmetry with the caller; confirm before removing.
  MS_EXCEPTION_IF_NULL(all_indice);
  MS_EXCEPTION_IF_NULL(gradient);
  MS_EXCEPTION_IF_NULL(indices);
  int64_t offset = 0;
  int64_t index = 0;
  // Byte size of one gradient segment; both source and destination segments
  // have identical size, so it serves as dst and src size for memcpy_s.
  size_t segment_data_size = segment_size * sizeof(T);
  for (auto &pair : indice_to_grads) {
    // Skip indices that are not owned by the [begin, end] range.
    if (distinct_ids.count(pair.first) == 0) {
      continue;
    }
    indices[index++] = pair.first;
    // Fix: the checked pointers are now the ones actually passed to memcpy_s
    // (previously dst_data/src_data were computed, null-checked, then ignored,
    // and separate dst_size/src_size locals duplicated segment_data_size).
    T *dst_data = gradient + offset;
    T *src_data = pair.second;
    MS_EXCEPTION_IF_NULL(dst_data);
    MS_EXCEPTION_IF_NULL(src_data);
    auto ret = memcpy_s(dst_data, segment_data_size, src_data, segment_data_size);
    if (ret != 0) {
      // Keeps the original best-effort behavior: log and abort the gather
      // instead of throwing (BuildSparseValue throws on the same failure).
      MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
      return;
    }
    offset += segment_size;
  }
}
// Assembles the outgoing sparse payload for one server: copies the untouched
// non-gradient segments from 'original_data', then overwrites the gradient and
// indice segments with the reduced values, using the (already reduced) lengths
// in 'lengths' to locate each segment inside 'reduced_data'.
template <typename T>
void WorkerProxy<T>::BuildSparseValue(const ::ps::SArray<int> &lengths, const size_t grad_index,
                                      const size_t indice_index, const T *original_data, const T *grads, int *indices,
                                      ::ps::SArray<T> *reduced_data) {
  MS_EXCEPTION_IF_NULL(original_data);
  MS_EXCEPTION_IF_NULL(grads);
  MS_EXCEPTION_IF_NULL(indices);
  MS_EXCEPTION_IF_NULL(reduced_data);
  int64_t offset = 0;
  size_t dst_size = 0;
  size_t src_size = 0;
  void *dst_data = nullptr;
  void *src_data = nullptr;
  // Pass 1: copy every segment except the gradient and indice ones verbatim.
  // NOTE(review): this assumes segment i starts at the same element offset in
  // both original_data and reduced_data, i.e. only segments at or after
  // grad_index change length — confirm against the caller's layout.
  for (size_t i = 0; i < lengths.size(); i++) {
    if (i != grad_index && i != indice_index) {
      int data_size = lengths[i] * sizeof(T);
      dst_size = data_size;
      src_size = data_size;
      dst_data = reduced_data->data() + offset;
      src_data = const_cast<T *>(original_data) + offset;
      MS_EXCEPTION_IF_NULL(dst_data);
      MS_EXCEPTION_IF_NULL(src_data);
      auto ret = memcpy_s(dst_data, dst_size, src_data, src_size);
      if (ret != 0) {
        MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
        return;
      }
    }
    // Advance by the segment length whether or not it was copied, so 'offset'
    // tracks element positions in the reduced layout.
    offset += lengths[i];
  }
  // Pass 2: fill the reduced gradient segment, which starts right after all
  // segments preceding grad_index.
  int64_t grad_offset = 0;
  for (size_t i = 0; i < grad_index; i++) {
    grad_offset += lengths[i];
  }
  int64_t data_size = lengths[grad_index] * sizeof(T);
  dst_size = data_size;
  src_size = data_size;
  dst_data = reduced_data->data() + grad_offset;
  src_data = const_cast<T *>(grads);
  MS_EXCEPTION_IF_NULL(dst_data);
  MS_EXCEPTION_IF_NULL(src_data);
  auto ret = memcpy_s(dst_data, dst_size, src_data, src_size);
  if (ret != 0) {
    MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
    return;
  }
  // Pass 3: fill the reduced indice segment, located directly after the
  // gradient segment. The int indices are bit-copied into the T-typed array
  // (byte count uses sizeof(T)) — NOTE(review): this presumes
  // sizeof(T) == sizeof(int) (e.g. T = float); verify for other T.
  int64_t indice_offset = grad_offset + lengths[grad_index];
  data_size = lengths[indice_index] * sizeof(T);
  T *indice_data = reduced_data->data() + indice_offset;
  dst_size = data_size;
  src_size = data_size;
  dst_data = indice_data;
  src_data = indices;
  MS_EXCEPTION_IF_NULL(dst_data);
  MS_EXCEPTION_IF_NULL(src_data);
  ret = memcpy_s(dst_data, dst_size, src_data, src_size);
  if (ret != 0) {
    MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
    return;
  }
}
template <typename T>
void WorkerProxy<T>::BroadcastSlicer(int64_t timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
                                     std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced,
                                     const std::map<int64_t, int64_t> &attr) {
  // Every server receives an identical copy of the full key-value payload.
  MS_EXCEPTION_IF_NULL(sliced);
  sliced->resize(server_num_);
  for (int64_t rank = 0; rank < server_num_; ++rank) {
    auto &slice = sliced->at(rank);
    slice.first = true;
    slice.second = send;
    // One response per server is expected for this timestamp.
    expected_result_count_[timestamp] += 1;
  }
}
template <typename T>
void WorkerProxy<T>::RoundRobinSlicer(int64_t timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
                                      std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced,
                                      const std::map<int64_t, int64_t> &attr) {
  // Route each key (and the value segment that belongs to it) to the server
  // recorded for it in key_to_server_id_.
  MS_EXCEPTION_IF_NULL(sliced);
  sliced->resize(server_num_);
  auto keys = send.keys;
  auto vals = send.vals;
  auto lens = send.lens;
  for (size_t i = 0; i < keys.size(); i++) {
    ::ps::Key param_key = keys[i];
    int64_t server_id = key_to_server_id_[param_key];
    auto &slice = sliced->at(server_id);
    if (!slice.first) {
      // First key routed to this server: activate its slice and expect a reply.
      slice.first = true;
      expected_result_count_[timestamp] += 1;
    }
    ::ps::KVPairs<T> &server_kv_pairs = slice.second;
    server_kv_pairs.keys.push_back(param_key);
    if (vals.empty()) {
      continue;
    }
    // Values are packed back to back, so this key's segment starts at the sum
    // of all preceding lengths.
    int64_t len = lens[i];
    int64_t offset = std::accumulate(lens.begin(), lens.begin() + i, 0);
    auto seg_begin = vals.begin() + offset;
    auto seg_end = seg_begin + len;
    for (auto iter = seg_begin; iter != seg_end; ++iter) {
      server_kv_pairs.vals.push_back(*iter);
    }
    server_kv_pairs.lens.push_back(len);
  }
}
template <typename T>
void WorkerProxy<T>::WorkerInitEmbeddingSlicer(int64_t timestamp, const ::ps::KVPairs<T> &send,
                                               const std::vector<::ps::Range> &,
                                               std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced,
                                               const std::map<int64_t, int64_t> &attrs) {
  // Split the initial embedding table row-wise across servers, following the
  // precomputed row ranges stored in embedding_table_ranges_.
  MS_EXCEPTION_IF_NULL(sliced);
  sliced->resize(server_num_);
  auto keys = send.keys;
  auto vals = send.vals;
  auto lens = send.lens;
  // Number of values per embedding row.
  size_t col_cnt = lens[0] / embedding_row_cnt_[keys[0]];
  const std::vector<::ps::Range> &ranges = *(embedding_table_ranges_[keys[0]]);
  for (size_t i = 0; i < ranges.size(); i++) {
    const ::ps::Range &range = ranges[i];
    // Row ranges are inclusive at both ends, hence the +1 for the end offset.
    size_t offset_begin = range.begin() * col_cnt;
    size_t offset_end = (range.end() + 1) * col_cnt;
    ::ps::KVPairs<T> kvs;
    kvs.keys = keys;
    kvs.vals = vals.segment(offset_begin, offset_end);
    kvs.lens.push_back(offset_end - offset_begin);
    auto &slice = sliced->at(i);
    slice.first = true;
    slice.second = kvs;
  }
}
// Slices an embedding-update request: for each server's row range, collects the
// lookup ids that fall inside the range plus their embedding vectors, and marks
// the slice invalid when no id matched.
// NOTE(review): 'timestamp' is declared as int here while sibling slicers use
// int64_t — looks like an inconsistency inherited from the declaration; confirm.
template <typename T>
void WorkerProxy<T>::UpdateEmbeddingSlicer(int timestamp, const ::ps::KVPairs<T> &send,
                                           const std::vector<::ps::Range> &,
                                           std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced,
                                           const std::map<int64_t, int64_t> &attrs) {
  MS_EXCEPTION_IF_NULL(sliced);
  // Layout of 'send': vals holds the embedding vectors, lens holds the lookup
  // ids (one id per embedding vector of size embedding_dim).
  T *embedding_vals = send.vals.data();
  int *lookup_ids = send.lens.data();
  size_t val_size = send.vals.size();
  size_t id_size = send.lens.size();
  size_t embedding_dim = val_size / id_size;
  const Key &key = send.keys[0];
  const std::vector<::ps::Range> &ranges = *(embedding_table_ranges_[key]);
  sliced->resize(ranges.size());
  for (size_t i = 0; i < ranges.size(); i++) {
    const ::ps::Range &range = ranges[i];
    const auto &begin = range.begin();
    const auto &end = range.end();
    auto &kvs = sliced->at(i).second;
    // The first key is always the embedding-table key; lookup ids follow.
    kvs.keys.push_back(key);
    for (size_t j = 0; j < id_size; j++) {
      auto lookup_id = static_cast<uint64_t>(lookup_ids[j]);
      // Row ranges are inclusive at both ends.
      if (lookup_id >= begin && lookup_id <= end) {
        kvs.keys.push_back(lookup_id);
        for (size_t k = 0; k < embedding_dim; k++) {
          kvs.vals.push_back(embedding_vals[j * embedding_dim + k]);
        }
      }
    }
    // Only the table key present means no id fell in this server's range:
    // mark the slice invalid so no message is sent to that server.
    if (kvs.keys.size() <= 1) {
      sliced->at(i).first = false;
    } else {
      sliced->at(i).first = true;
      expected_result_count_[timestamp] += 1;
    }
  }
}
template <typename T>
void WorkerProxy<T>::ProcessLookupResult(const ::ps::Message &msg) {
  // Records one server's embedding-lookup reply; once replies from all servers
  // have arrived, fires and discards the callback registered for the timestamp.
  int64_t ts = msg.meta.timestamp;
  if (msg.meta.pull) {
    CHECK_GE(msg.data.size(), (size_t)2);
    ::ps::KVPairs<T> reply;
    reply.keys = msg.data[0];
    reply.vals = msg.data[1];
    const bool has_lens = msg.data.size() > (size_t)2;
    if (has_lens) {
      reply.lens = msg.data[2];
    }
    // lookup_results_ is shared with the consumer thread; guard the append.
    mutex_.lock();
    lookup_results_[ts].push_back(reply);
    mutex_.unlock();
  }
  // +1 accounts for the reply currently being processed.
  if (lookup_customer_->NumResponse(ts) + 1 == server_num_) {
    const auto &cb = lookup_callbacks_[ts];
    cb();
    lookup_callbacks_.erase(ts);
  }
}
template <typename T>
void WorkerProxy<T>::ProcessResponse(const ::ps::Message &msg) {
  // Stores one server's pull reply under its server rank; fires the registered
  // callback once every server has answered for this timestamp.
  int64_t ts = msg.meta.timestamp;
  if (msg.meta.pull) {
    CHECK_GE(msg.data.size(), (size_t)2);
    ::ps::KVPairs<T> reply;
    reply.keys = msg.data[0];
    reply.vals = msg.data[1];
    const bool has_lens = msg.data.size() > (size_t)2;
    if (has_lens) {
      reply.lens = msg.data[2];
    }
    // gathered_response_ is shared state; guard the per-rank insertion.
    mutex_.lock();
    int rsp_server_rank = ::ps::Postoffice::Get()->IDtoRank(msg.meta.sender);
    gathered_response_[ts][rsp_server_rank] = reply;
    mutex_.unlock();
    // +1 accounts for the reply currently being processed.
    if (general_customer_->NumResponse(ts) + 1 == server_num_) {
      const auto &cb = general_callbacks_[ts];
      cb();
      general_callbacks_.erase(ts);
    }
  }
}
template <typename T>
void WorkerProxy<T>::Send(::ps::Customer *customer, int64_t timestamp, bool push, bool pull, int64_t cmd,
                          const ::ps::KVPairs<T> &kvs, const Slicer &slicer, std::map<int64_t, int64_t> attrs) {
  // Slices 'kvs' across servers with 'slicer', then sends one request message
  // to each server whose slice was marked valid.
  MS_EXCEPTION_IF_NULL(customer);
  SlicedKVs sliced;
  slicer(timestamp, kvs, ::ps::Postoffice::Get()->GetServerKeyRanges(), &sliced, attrs);
  for (size_t i = 0; i < sliced.size(); i++) {
    const auto &s = sliced[i];
    // Slices flagged false carry no payload for this server; skip them.
    if (!s.first) continue;
    ::ps::Message msg;
    msg.meta.app_id = customer->app_id();
    msg.meta.customer_id = customer->customer_id();
    msg.meta.request = true;
    msg.meta.push = push;
    msg.meta.pull = pull;
    msg.meta.head = cmd;
    msg.meta.timestamp = timestamp;
    msg.meta.recver = ::ps::Postoffice::Get()->ServerRankToID(i);
    msg.meta.priority = kvs.priority;
    // Fix: renamed from 'kvs' — the original shadowed the function parameter
    // 'kvs' that is still read for 'priority' just above.
    const auto &server_kvs = s.second;
    if (server_kvs.keys.size()) {
      msg.AddData(server_kvs.keys);
      msg.AddData(server_kvs.vals);
      if (server_kvs.lens.size()) {
        msg.AddData(server_kvs.lens);
      }
    }
    ::ps::Postoffice::Get()->van()->Send(msg);
  }
}
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_WORKER_PROXY_H_

View File

@ -24,7 +24,7 @@ namespace mindspore {
namespace device {
void KernelRuntimeManager::ClearRuntimeResource() {
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
if (ps::Util::IsRoleOfWorker() && ps::PsDataPrefetch::GetInstance().cache_enable()) {
if (ps::PSContext::instance()->is_worker() && ps::PsDataPrefetch::GetInstance().cache_enable()) {
ps::ps_cache_instance.SyncEmbeddingTable();
}
#endif

View File

@ -78,7 +78,6 @@ export DEVICE_NUM=8
export RANK_SIZE=8
export RANK_TABLE_FILE=$PATH1
export MS_COMM_TYPE=zmq
export MS_SCHED_NUM=1
export MS_WORKER_NUM=$RANK_SIZE
export MS_SERVER_NUM=8

View File

@ -70,7 +70,6 @@ fi
export DEVICE_NUM=8
export RANK_SIZE=8
export MS_COMM_TYPE=zmq
export MS_SCHED_NUM=1
export MS_WORKER_NUM=8
export MS_SERVER_NUM=8

View File

@ -27,7 +27,6 @@ export EPOCH_SIZE=$2
export DEVICE_TARGET=$3
export DATASET=$4
export MS_COMM_TYPE=zmq
export MS_SCHED_NUM=1
export MS_WORKER_NUM=$RANK_SIZE
export LOCAL_WORKER_NUM=$5

View File

@ -25,7 +25,6 @@ export RANK_SIZE=$1
export EPOCH_SIZE=$2
export DEVICE_TARGET=$3
export DATASET=$4
export MS_COMM_TYPE=zmq
export MS_SCHED_NUM=1
export MS_WORKER_NUM=$RANK_SIZE
export MS_SERVER_NUM=$5

View File

@ -23,7 +23,6 @@ self_path=$(dirname "${script_self}")
export EPOCH_SIZE=$1
export DEVICE_TARGET=$2
export DATASET=$3
export MS_COMM_TYPE=zmq
export MS_SCHED_NUM=1
export MS_WORKER_NUM=1
export MS_SERVER_NUM=$4

View File

@ -15,8 +15,7 @@
# ============================================================================
execute_path=$(pwd)
self_path=$(dirname "${script_self}")
export MS_COMM_TYPE=zmq
self_path=$(dirname $0)
export MS_SCHED_NUM=1
DEVICE_TARGET=$1
export MS_WORKER_NUM=$2

View File

@ -15,8 +15,7 @@
# ============================================================================
execute_path=$(pwd)
self_path=$(dirname "${script_self}")
export MS_COMM_TYPE=zmq
self_path=$(dirname $0)
export MS_SCHED_NUM=1
DEVICE_TARGET=$1
DATASET_PATH=$2

View File

@ -15,8 +15,7 @@
# ============================================================================
execute_path=$(pwd)
self_path=$(dirname "${script_self}")
export MS_COMM_TYPE=zmq
self_path=$(dirname $0)
export MS_SCHED_NUM=1
DEVICE_TARGET=$1
export MS_WORKER_NUM=$2

View File

@ -15,8 +15,7 @@
# ============================================================================
execute_path=$(pwd)
self_path=$(dirname "${script_self}")
export MS_COMM_TYPE=zmq
self_path=$(dirname $0)
export MS_SCHED_NUM=1
DEVICE_TARGET=$1
DATASET_PATH=$2

View File

@ -150,6 +150,8 @@ list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/internal/parame
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/scheduler.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/optimizer_info.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/optimizer_info_builder.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/worker.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/parameter_server.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/ps_cache/gpu/gpu_ps_cache.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/ps_cache/ascend/ascend_ps_cache.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc")

View File

@ -1,255 +0,0 @@
diff -Npur ps-lite-master/include/dmlc/base.h ps-lite-master-new/include/dmlc/base.h
--- ps-lite-master/include/dmlc/base.h 2020-02-29 13:59:55.000000000 +0800
+++ ps-lite-master-new/include/dmlc/base.h 2020-07-01 11:56:50.444833389 +0800
@@ -8,7 +8,7 @@
/*! \brief whether use glog for logging */
#ifndef DMLC_USE_GLOG
-#define DMLC_USE_GLOG 0
+#define DMLC_USE_GLOG 1
#endif
/*!
diff -Npur ps-lite-master/include/dmlc/logging.h ps-lite-master-new/include/dmlc/logging.h
--- ps-lite-master/include/dmlc/logging.h 2020-02-29 13:59:55.000000000 +0800
+++ ps-lite-master-new/include/dmlc/logging.h 2020-07-08 21:35:33.334584767 +0800
@@ -52,7 +52,7 @@ struct Error : public std::runtime_error
namespace dmlc {
inline void InitLogging(const char* argv0) {
- google::InitGoogleLogging(argv0);
+ //google::InitGoogleLogging(argv0);
}
} // namespace dmlc
diff -Npur ps-lite-master/make/deps.mk ps-lite-master-new/make/deps.mk
--- ps-lite-master/make/deps.mk 2020-02-29 13:59:55.000000000 +0800
+++ ps-lite-master-new/make/deps.mk 2020-06-17 10:35:46.253837426 +0800
@@ -1,69 +1,7 @@
# Install dependencies
-
-URL1=https://raw.githubusercontent.com/mli/deps/master/build
-URL2=https://github.com/google/protobuf/releases/download/v3.5.1
-ifndef WGET
-WGET = wget
-endif
-
-# protobuf
-PROTOBUF = ${DEPS_PATH}/include/google/protobuf/message.h
-${PROTOBUF}:
- $(eval FILE=protobuf-cpp-3.5.1.tar.gz)
- $(eval DIR=protobuf-3.5.1)
- rm -rf $(FILE) $(DIR)
- $(WGET) $(URL2)/$(FILE) && tar --no-same-owner -zxf $(FILE)
- cd $(DIR) && export CFLAGS=-fPIC && export CXXFLAGS=-fPIC && ./configure -prefix=$(DEPS_PATH) && $(MAKE) && $(MAKE) install
- rm -rf $(FILE) $(DIR)
-
# zmq
-ZMQ = ${DEPS_PATH}/include/zmq.h
+ZMQ = $(MS_ZMQ_INSTALL_PATH)/lib/libzmq.a
${ZMQ}:
- $(eval FILE=zeromq-4.1.4.tar.gz)
- $(eval DIR=zeromq-4.1.4)
- rm -rf $(FILE) $(DIR)
- $(WGET) $(URL1)/$(FILE) && tar --no-same-owner -zxf $(FILE)
- cd $(DIR) && export CFLAGS=-fPIC && export CXXFLAGS=-fPIC && ./configure -prefix=$(DEPS_PATH) --with-libsodium=no --with-libgssapi_krb5=no && $(MAKE) && $(MAKE) install
- rm -rf $(FILE) $(DIR)
-
-# lz4
-LZ4 = ${DEPS_PATH}/include/lz4.h
-${LZ4}:
- $(eval FILE=lz4-r129.tar.gz)
- $(eval DIR=lz4-r129)
- rm -rf $(FILE) $(DIR)
- wget $(URL1)/$(FILE) && tar --no-same-owner -zxf $(FILE)
- cd $(DIR) && $(MAKE) && PREFIX=$(DEPS_PATH) $(MAKE) install
- rm -rf $(FILE) $(DIR)
-
-# cityhash
-CITYHASH = ${DEPS_PATH}/include/city.h
-${CITYHASH}:
- $(eval FILE=cityhash-1.1.1.tar.gz)
- $(eval DIR=cityhash-1.1.1)
- rm -rf $(FILE) $(DIR)
- wget $(URL1)/$(FILE) && tar --no-same-owner -zxf $(FILE)
- cd $(DIR) && ./configure -prefix=$(DEPS_PATH) --enable-sse4.2 && $(MAKE) CXXFLAGS="-g -O3 -msse4.2" && $(MAKE) install
- rm -rf $(FILE) $(DIR)
-
-
-# # gflags
-# ${DEPS_PATH}/include/google/gflags.h:
-# $(eval FILE=gflags-2.0-no-svn-files.tar.gz)
-# $(eval DIR=gflags-2.0)
-# rm -rf $(FILE) $(DIR)
-# wget $(URL)/$(FILE) && tar -zxf $(FILE)
-# cd $(DIR) && ./configure -prefix=$(DEPS_PATH) && $(MAKE) && $(MAKE) install
-# rm -rf $(FILE) $(DIR)
-# gflags: | ${DEPS_PATH}/include/google/gflags.h
+ cd $(MS_ZMQ_DIR) && export CFLAGS="-fPIC -D_GLIBCXX_USE_CXX11_ABI=0" && export CXXFLAGS=-fPIC && ./configure -prefix=$(MS_ZMQ_INSTALL_PATH) --with-libsodium=no --with-libgssapi_krb5=no && $(MAKE) && $(MAKE) install
-# # glog
-# ${DEPS_PATH}/include/glog/logging.h: | ${DEPS_PATH}/include/google/gflags.h
-# $(eval FILE=v0.3.4.tar.gz)
-# $(eval DIR=glog-0.3.4)
-# rm -rf $(FILE) $(DIR)
-# wget https://github.com/google/glog/archive/$(FILE) && tar -zxf $(FILE)
-# cd $(DIR) && ./configure -prefix=$(DEPS_PATH) --with-gflags=$(DEPS_PATH) && $(MAKE) && $(MAKE) install
-# rm -rf $(FILE) $(DIR)
-# glog: | ${DEPS_PATH}/include/glog/logging.h
diff -Npur ps-lite-master/make/ps.mk ps-lite-master-new/make/ps.mk
--- ps-lite-master/make/ps.mk 2020-02-29 13:59:55.000000000 +0800
+++ ps-lite-master-new/make/ps.mk 2020-06-05 09:28:35.337740291 +0800
@@ -9,5 +9,5 @@ ifeq ($(USE_KEY32), 1)
ADD_CFLAGS += -DUSE_KEY32=1
endif
-PS_LDFLAGS_SO = -L$(DEPS_PATH)/lib -lprotobuf-lite -lzmq
-PS_LDFLAGS_A = $(addprefix $(DEPS_PATH)/lib/, libprotobuf-lite.a libzmq.a)
+PS_LDFLAGS_SO = -L$(MS_ZMQ_INSTALL_PATH)/lib -lzmq -L$(MS_PROTO_LIB_DIR) -lprotobuf-lite
+PS_LDFLAGS_A = $(addprefix $(MS_ZMQ_INSTALL_PATH)/lib -L$(MS_PROTO_LIB_DIR), libprotobuf-lite.a libzmq.a)
diff -Npur ps-lite-master/Makefile ps-lite-master-new/Makefile
--- ps-lite-master/Makefile 2020-02-29 13:59:55.000000000 +0800
+++ ps-lite-master-new/Makefile 2020-06-17 11:09:20.240322660 +0800
@@ -12,13 +12,24 @@ ifndef DEPS_PATH
DEPS_PATH = $(shell pwd)/deps
endif
+MS_PROTO_DIR = @protobuf_DIRPATH@
+MS_GLOG_DIR = @glog_DIRPATH@
+MS_ZMQ_DIR = @zeromq_DIRPATH@
+
+MS_PROTO_LIB_DIR = @protobuf_LIBPATH@
+MS_GLOG_LIB_DIR = @glog_LIBPATH@
+MS_ZMQ_INSTALL_PATH = $(MS_ZMQ_DIR)/zmq_install
ifndef PROTOC
-PROTOC = ${DEPS_PATH}/bin/protoc
+PROTOC = $(MS_PROTO_DIR)/bin/protoc
endif
-INCPATH = -I./src -I./include -I$(DEPS_PATH)/include
-CFLAGS = -std=c++11 -msse2 -fPIC -O3 -ggdb -Wall -finline-functions $(INCPATH) $(ADD_CFLAGS)
+INCPATH = -I./src -I./include -I$(MS_ZMQ_INSTALL_PATH)/include
+INCPATH += -I$(MS_PROTO_DIR)/include
+INCPATH += -I$(MS_GLOG_DIR)/include
+
+CXXFLAGS = -D_GLIBCXX_USE_CXX11_ABI=0
+CFLAGS = -std=c++11 -fPIC -O3 -ggdb -Wall -finline-functions $(INCPATH) $(ADD_CFLAGS) -D_GLIBCXX_USE_CXX11_ABI=0
LIBS = -pthread
ifdef USE_IBVERBS
@@ -30,6 +41,7 @@ ifdef ASAN
CFLAGS += -fsanitize=address -fno-omit-frame-pointer -fno-optimize-sibling-calls
endif
+LIBS += -L$(MS_GLOG_LIB_DIR) -lglog
all: ps test
@@ -51,9 +63,9 @@ build/libps.a: $(OBJS)
build/%.o: src/%.cc ${ZMQ} src/meta.pb.h
@mkdir -p $(@D)
$(CXX) $(INCPATH) -std=c++11 -MM -MT build/$*.o $< >build/$*.d
- $(CXX) $(CFLAGS) $(LIBS) -c $< -o $@
+ $(CXX) $(CFLAGS) $(CXXFLAGS) $(LIBS) -c $< -o $@
-src/%.pb.cc src/%.pb.h : src/%.proto ${PROTOBUF}
+src/%.pb.cc src/%.pb.h : src/%.proto
$(PROTOC) --cpp_out=./src --proto_path=./src $<
-include build/*.d
diff -Npur ps-lite-master/src/ibverbs_van.h ps-lite-master-new/src/ibverbs_van.h
--- ps-lite-master/src/ibverbs_van.h 2020-02-29 13:59:55.000000000 +0800
+++ ps-lite-master-new/src/ibverbs_van.h 2020-06-02 20:52:11.076230014 +0800
@@ -145,15 +145,15 @@ class SimpleMempool {
total_allocated_size += new_mem_size;
}
- CHECK_NE(free_list.end(), it) << "Not enough memory";
+ //CHECK_NE(free_list.end(), it) << "Not enough memory";
CHECK_GE(it->first, proper_size);
char *addr = it->second;
size_t space_left = it->first - proper_size;
free_list.erase(it);
- CHECK_EQ(used_list.find(addr), used_list.end())
- << "Address is already allocated";
+ //CHECK_EQ(used_list.find(addr), used_list.end())
+ //<< "Address is already allocated";
used_list.emplace(addr, proper_size);
@@ -173,8 +173,8 @@ class SimpleMempool {
std::lock_guard<std::mutex> lk(mu_);
auto it = used_list.find(addr);
- CHECK_NE(used_list.end(), it)
- << "Cannot find info about address: " << (uintptr_t)addr;
+ //CHECK_NE(used_list.end(), it)
+ //<< "Cannot find info about address: " << (uintptr_t)addr;
size_t size = it->second;
used_list.erase(it);
@@ -208,7 +208,7 @@ class SimpleMempool {
// Convert the memory address to its associated RDMA memory region
inline struct ibv_mr *Addr2MR(char *addr) {
auto it = mr_list.lower_bound(addr);
- CHECK_NE(it, mr_list.end()) << "cannot find the associated memory region";
+ //CHECK_NE(it, mr_list.end()) << "cannot find the associated memory region";
return it->second;
}
};
@@ -330,7 +330,7 @@ class AddressPool {
CHECK(ptr);
uint32_t idx = indices_.front();
indices_.pop();
- CHECK_EQ(table_[idx], nullptr);
+ //CHECK_EQ(table_[idx], nullptr);
table_[idx] = ptr;
return idx;
}
@@ -636,7 +636,7 @@ class IBVerbsVan : public Van {
PBMeta meta;
PackMetaPB(msg.meta, &meta);
- CHECK_NE(endpoints_.find(remote_id), endpoints_.end());
+ //CHECK_NE(endpoints_.find(remote_id), endpoints_.end());
Endpoint *endpoint = endpoints_[remote_id].get();
MessageBuffer *msg_buf = new MessageBuffer();
diff -Npur ps-lite-master/src/van.cc ps-lite-master-new/src/van.cc
--- ps-lite-master/src/van.cc 2020-02-29 13:59:55.000000000 +0800
+++ ps-lite-master-new/src/van.cc 2020-06-02 20:52:43.330405828 +0800
@@ -448,6 +448,7 @@ void Van::PackMetaPB(const Meta& meta, P
if (meta.timestamp != Meta::kEmpty) pb->set_timestamp(meta.timestamp);
if (meta.body.size()) pb->set_body(meta.body);
pb->set_push(meta.push);
+ pb->set_pull(meta.pull);
pb->set_request(meta.request);
pb->set_simple_app(meta.simple_app);
pb->set_priority(meta.priority);
diff -Npur ps-lite-master/tests/test.mk ps-lite-master-new/tests/test.mk
--- ps-lite-master/tests/test.mk 2020-02-29 13:59:55.000000000 +0800
+++ ps-lite-master-new/tests/test.mk 2020-06-16 19:15:06.025087897 +0800
@@ -1,10 +1,10 @@
-TEST_SRC = $(wildcard tests/test_*.cc)
-TEST = $(patsubst tests/test_%.cc, tests/test_%, $(TEST_SRC))
+#TEST_SRC = $(wildcard tests/test_*.cc)
+#TEST = $(patsubst tests/test_%.cc, tests/test_%, $(TEST_SRC))
-# -ltcmalloc_and_profiler
-LDFLAGS = -Wl,-rpath,$(DEPS_PATH)/lib $(PS_LDFLAGS_SO) -pthread
-tests/% : tests/%.cc build/libps.a
- $(CXX) $(CFLAGS) $(LIBS) -MM -MT tests/$* $< >tests/$*.d
- $(CXX) $(CFLAGS) $(LIBS) -o $@ $(filter %.cc %.a, $^) $(LDFLAGS)
-
--include tests/*.d
+## -ltcmalloc_and_profiler
+#LDFLAGS = -Wl,-rpath,$(DEPS_PATH)/lib $(PS_LDFLAGS_SO) -pthread
+#tests/% : tests/%.cc build/libps.a
+# $(CXX) $(CFLAGS) $(LIBS) -MM -MT tests/$* $< >tests/$*.d
+# $(CXX) $(CFLAGS) $(LIBS) -o $@ $(filter %.cc %.a, $^) $(LDFLAGS)
+#
+#-include tests/*.d