forked from mindspore-Ecosystem/mindspore
replace ps-lite
This commit is contained in:
parent
e99c29c7d9
commit
db0a6f1e19
|
@ -1,22 +0,0 @@
|
|||
if(ENABLE_GITEE)
|
||||
set(REQ_URL "https://gitee.com/mirrors/ps-lite/repository/archive/34fd45cae457d59850fdcb2066467778d0673f21.zip")
|
||||
set(MD5 "0d1543b8dcb0bc3610637e1643c94eb4")
|
||||
else()
|
||||
set(REQ_URL "https://github.com/dmlc/ps-lite/archive/34fd45cae457d59850fdcb2066467778d0673f21.zip")
|
||||
set(MD5 "393c0e27b68bfaf96718caa3aa96f5a3")
|
||||
endif()
|
||||
|
||||
set(pslite_USE_STATIC_LIBS ON)
|
||||
if(${ENABLE_IBVERBS} STREQUAL "ON")
|
||||
set(pslite_CXXFLAGS "USE_IBVERBS=1")
|
||||
endif()
|
||||
mindspore_add_pkg(pslite
|
||||
LIBS ps
|
||||
URL ${REQ_URL}
|
||||
MD5 ${MD5}
|
||||
PATCHES ${CMAKE_SOURCE_DIR}/third_party/patch/pslite/ps_lite.patch001
|
||||
ONLY_MAKE True
|
||||
ONLY_MAKE_INCS include/*
|
||||
ONLY_MAKE_LIBS build/*)
|
||||
include_directories(${pslite_INC})
|
||||
add_library(mindspore::pslite ALIAS pslite::ps)
|
|
@ -1,5 +0,0 @@
|
|||
mindspore_add_pkg(zeromq
|
||||
VER 4.1.4
|
||||
HEAD_ONLY ./
|
||||
URL https://raw.githubusercontent.com/mli/deps/master/build/zeromq-4.1.4.tar.gz
|
||||
MD5 a611ecc93fffeb6d058c0e6edf4ad4fb)
|
|
@ -32,10 +32,6 @@ include(${CMAKE_SOURCE_DIR}/cmake/external_libs/flatbuffers.cmake)
|
|||
if(USE_GLOG)
|
||||
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/glog.cmake)
|
||||
endif()
|
||||
if(ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))
|
||||
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/zeromq.cmake)
|
||||
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/pslite.cmake)
|
||||
endif()
|
||||
|
||||
find_package(Python3)
|
||||
include_directories(${Python3_INCLUDE_DIRS})
|
||||
|
|
|
@ -339,8 +339,8 @@ elseif(CMAKE_SYSTEM_NAME MATCHES "Darwin")
|
|||
target_link_libraries(_c_expression PRIVATE -Wl,-force_load mindspore -Wl,-noall_load)
|
||||
else()
|
||||
if(ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))
|
||||
target_link_libraries(mindspore mindspore::pslite proto_input mindspore::protobuf
|
||||
mindspore::event mindspore::event_pthreads ${zeromq_DIRPATH}/zmq_install/lib/libzmq.a)
|
||||
target_link_libraries(mindspore proto_input mindspore::protobuf
|
||||
mindspore::event mindspore::event_pthreads)
|
||||
target_link_libraries(mindspore -Wl,--no-as-needed mindspore::event_core ps_cache)
|
||||
if(${ENABLE_IBVERBS} STREQUAL "ON")
|
||||
target_link_libraries(mindspore ibverbs rdmacm)
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include "ps/worker.h"
|
||||
#include "ps/util.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
@ -35,7 +36,7 @@ void EmbeddingLookUpProxyKernel::InitKernel(const CNodePtr &kernel_node) {
|
|||
<< input_shape << " is too large.";
|
||||
}
|
||||
|
||||
if (mindspore::ps::Util::IsRoleOfWorker()) {
|
||||
if (mindspore::ps::PSContext::instance()->is_worker()) {
|
||||
key_ = AnfAlgo::GetNodeAttr<size_t>(kernel_node, kAttrPsKey);
|
||||
}
|
||||
std::vector<size_t> keys{key_, key_, key_};
|
||||
|
@ -50,9 +51,10 @@ void EmbeddingLookUpProxyKernel::InitKernel(const CNodePtr &kernel_node) {
|
|||
<< ", indices_shape:" << indices_shape << ", output_shape:" << output_shape;
|
||||
std::vector<int64_t> lens{SizeToLong(input_shape.size()), SizeToLong(indices_shape.size()),
|
||||
SizeToLong(output_shape.size())};
|
||||
if (mindspore::ps::Util::IsRoleOfWorker()) {
|
||||
if (mindspore::ps::PSContext::instance()->is_worker()) {
|
||||
mindspore::ps::worker.AddEmbeddingTable(key_, input_shape[axis]);
|
||||
mindspore::ps::worker.InitPSEmbeddingTable(keys, values, lens);
|
||||
mindspore::ps::ParamInitInfoMessage info;
|
||||
mindspore::ps::worker.InitPSEmbeddingTable(key_, input_shape, indices_shape, output_shape, info);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -70,17 +72,16 @@ bool EmbeddingLookUpProxyKernel::Launch(const std::vector<kernel::AddressPtr> &i
|
|||
size_t input_size = inputs[1]->size;
|
||||
size_t output_size = outputs[0]->size;
|
||||
|
||||
size_t size = input_size / sizeof(float);
|
||||
::ps::SArray<int> lookup_ids(size, 0);
|
||||
::ps::SArray<int> lengths{size};
|
||||
::ps::SArray<float> lookup_result(output_size / sizeof(float), 0);
|
||||
size_t size = input_size / sizeof(int);
|
||||
std::vector<int> lookup_ids(size, 0);
|
||||
std::vector<int> lengths{SizeToInt(size)};
|
||||
std::vector<float> lookup_result(output_size / sizeof(float), 0);
|
||||
auto ret = memcpy_s(lookup_ids.data(), lookup_ids.size() * sizeof(int), indices_addr, input_size);
|
||||
if (ret != EOK) {
|
||||
MS_LOG(EXCEPTION) << "Lookup id memcpy failed.";
|
||||
return false;
|
||||
}
|
||||
mindspore::ps::worker.DoPSEmbeddingLookup({key_}, lookup_ids, lengths, &lookup_result,
|
||||
mindspore::ps::kEmbeddingLookupCmd);
|
||||
mindspore::ps::worker.DoPSEmbeddingLookup(key_, lookup_ids, &lookup_result, mindspore::ps::kEmbeddingLookupCmd);
|
||||
|
||||
auto ret2 = memcpy_s(output_addr, outputs[0]->size, lookup_result.data(), output_size);
|
||||
if (ret2 != EOK) {
|
||||
|
|
|
@ -62,7 +62,7 @@ class PullKernel : public CPUKernel {
|
|||
MS_EXCEPTION_IF_NULL(param_node);
|
||||
param_name_ = param_node->fullname_with_scope();
|
||||
|
||||
if (mindspore::ps::Util::IsRoleOfWorker()) {
|
||||
if (mindspore::ps::PSContext::instance()->is_worker()) {
|
||||
key_ = AnfAlgo::GetNodeAttr<size_t>(kernel_node, kAttrPsKey);
|
||||
}
|
||||
InitSizeLists();
|
||||
|
|
|
@ -30,6 +30,7 @@
|
|||
#include "backend/optimizer/pass/replace_node_by_proxy.h"
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
#include "ps/util.h"
|
||||
#include "ps/ps_context.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -75,9 +76,9 @@ GraphId CPUSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtr
|
|||
MS_LOG(INFO) << "Set kernel info";
|
||||
SetKernelInfo(graph.get());
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
if (ps::Util::IsParamServerMode()) {
|
||||
if (ps::PSContext::instance()->is_ps_mode()) {
|
||||
AssignParamKey(graph);
|
||||
if (ps::Util::IsRoleOfWorker()) {
|
||||
if (ps::PSContext::instance()->is_worker()) {
|
||||
Optimize(graph);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -41,8 +41,9 @@
|
|||
#include "utils/trace_base.h"
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
#include "ps/ps_cache/ps_cache_manager.h"
|
||||
#include "ps/common.h"
|
||||
#include "ps/constants.h"
|
||||
#include "ps/util.h"
|
||||
#include "ps/ps_context.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#endif
|
||||
|
||||
|
@ -2287,7 +2288,7 @@ void SessionBasic::RunOpHideNopNode(const KernelGraphPtr &kernel_graph) const {
|
|||
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
void SessionBasic::InitPsWorker(const KernelGraphPtr &kernel_graph) {
|
||||
if (!ps::Util::IsRoleOfWorker()) {
|
||||
if (!ps::PSContext::instance()->is_worker()) {
|
||||
return;
|
||||
}
|
||||
CheckPSModeConsistence(kernel_graph);
|
||||
|
@ -2384,7 +2385,7 @@ void SessionBasic::AssignParamKey(const KernelGraphPtr &kernel_graph) {
|
|||
|
||||
void SessionBasic::InitPSParamAndOptim(const KernelGraphPtr &kernel_graph,
|
||||
const std::vector<tensor::TensorPtr> &inputs_const) {
|
||||
if (!ps::Util::IsRoleOfWorker()) {
|
||||
if (!ps::PSContext::instance()->is_worker()) {
|
||||
return;
|
||||
}
|
||||
std::vector<tensor::TensorPtr> inputs(inputs_const);
|
||||
|
|
|
@ -48,6 +48,7 @@
|
|||
#include "mindspore/core/utils/parallel_node_check.h"
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
#include "ps/util.h"
|
||||
#include "ps/ps_context.h"
|
||||
#endif
|
||||
|
||||
using mindspore::tensor::Tensor;
|
||||
|
@ -3283,7 +3284,7 @@ static void HandleNoUsedParameter(const FuncGraphPtr &root) {
|
|||
|
||||
bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) {
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
if (ps::Util::IsRoleOfPServer() || ps::Util::IsRoleOfScheduler()) {
|
||||
if (ps::PSContext::instance()->is_server() || ps::PSContext::instance()->is_scheduler()) {
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -288,7 +288,6 @@ if(${CMAKE_SYSTEM_NAME} MATCHES "Windows")
|
|||
else()
|
||||
target_link_libraries(_c_dataengine PRIVATE _c_mindrecord)
|
||||
if(ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))
|
||||
target_link_libraries(_c_dataengine PRIVATE mindspore::pslite ${zeromq_DIRPATH}/zmq_install/lib/libzmq.a)
|
||||
if(${ENABLE_IBVERBS} STREQUAL "ON")
|
||||
target_link_libraries(_c_dataengine PRIVATE ibverbs rdmacm)
|
||||
endif()
|
||||
|
|
|
@ -460,7 +460,7 @@ bool StartPSWorkerAction(const ResourcePtr &res) {
|
|||
|
||||
bool StartPSServerAction(const ResourcePtr &res) {
|
||||
FuncGraphPtr func_graph = res->func_graph();
|
||||
auto &ps = ps::ParameterServer<float>::GetInstance();
|
||||
auto &ps = ps::ParameterServer::GetInstance();
|
||||
ps.Run(func_graph);
|
||||
return true;
|
||||
}
|
||||
|
@ -626,7 +626,7 @@ std::vector<ActionItem> VmPipeline() {
|
|||
|
||||
actions.emplace_back(std::make_pair("validate", ValidateAction));
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
if (ps::Util::IsRoleOfWorker()) {
|
||||
if (ps::PSContext::instance()->is_worker()) {
|
||||
actions.emplace_back(std::make_pair("worker", StartPSWorkerAction));
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -43,6 +43,7 @@
|
|||
#include "pipeline/jit/static_analysis/auto_monad.h"
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
#include "ps/util.h"
|
||||
#include "ps/ps_context.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -406,7 +407,7 @@ bool AddRecomputationPass(const ResourcePtr &res) {
|
|||
|
||||
bool AddCacheEmbeddingPass(const ResourcePtr &res) {
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
if (ps::Util::IsParamServerMode()) {
|
||||
if (ps::PSContext::instance()->is_ps_mode()) {
|
||||
return true;
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -49,7 +49,7 @@
|
|||
#include "utils/shape_utils.h"
|
||||
#include "utils/info.h"
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
#include "ps/common.h"
|
||||
#include "ps/constants.h"
|
||||
#include "ps/util.h"
|
||||
#include "ps/worker.h"
|
||||
#include "ps/ps_cache/ps_data/ps_data_prefetch.h"
|
||||
|
@ -492,14 +492,11 @@ std::vector<ActionItem> GetPipline(const ResourcePtr &resource, const std::strin
|
|||
std::string backend = MsContext::GetInstance()->backend_policy();
|
||||
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
if (mindspore::ps::Util::IsParamServerMode()) {
|
||||
mindspore::ps::Util::SetInternalEnvVar();
|
||||
}
|
||||
if (ps::Util::IsRoleOfPServer()) {
|
||||
if (ps::PSContext::instance()->is_server()) {
|
||||
resource->results()[kBackend] = compile::CreateBackend();
|
||||
return PServerPipeline();
|
||||
}
|
||||
if (ps::Util::IsRoleOfScheduler()) {
|
||||
if (ps::PSContext::instance()->is_scheduler()) {
|
||||
return PSchedulerPipeline();
|
||||
}
|
||||
#endif
|
||||
|
@ -978,7 +975,7 @@ bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batc
|
|||
const std::vector<TypePtr> &types, const std::vector<std::vector<int64_t>> &shapes,
|
||||
const std::vector<int64_t> &input_indexes, bool need_run) {
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
if ((ps::Util::IsParamServerMode()) && (!ps::Util::IsRoleOfWorker())) {
|
||||
if ((ps::PSContext::instance()->is_ps_mode()) && (!ps::PSContext::instance()->is_worker())) {
|
||||
return true;
|
||||
}
|
||||
#endif
|
||||
|
@ -1030,7 +1027,7 @@ bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batc
|
|||
ConfigManager::GetInstance().set_iter_num(size);
|
||||
// PS cache does not support loop sink.
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
if (ps::Util::IsRoleOfWorker() && ps::PsDataPrefetch::GetInstance().cache_enable()) {
|
||||
if (ps::PSContext::instance()->is_worker() && ps::PsDataPrefetch::GetInstance().cache_enable()) {
|
||||
ps::PsDataPrefetch::GetInstance().CreateDataChannel(queue_name, LongToSize(size));
|
||||
ConfigManager::GetInstance().set_iter_num(1);
|
||||
}
|
||||
|
@ -1151,10 +1148,11 @@ void ClearResAtexit() {
|
|||
pynative::ClearPyNativeSession();
|
||||
session::ClearPythonParasMap();
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
if (ps::Util::IsParamServerMode() && ps::Util::IsRoleOfWorker()) {
|
||||
if (ps::PSContext::instance()->is_ps_mode() && ps::PSContext::instance()->is_worker()) {
|
||||
if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
|
||||
ps::ps_cache_instance.Finalize();
|
||||
}
|
||||
MS_LOG(INFO) << "ps::worker.Finalize";
|
||||
ps::worker.Finalize();
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -21,8 +21,8 @@ if(NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)))
|
|||
list(REMOVE_ITEM _PS_SRC_FILES "core/abstract_node.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "core/scheduler_node.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "core/http_client.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "internal/worker.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "internal/parameter_server.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "worker.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "parameter_server.cc")
|
||||
endif()
|
||||
|
||||
if(NOT ENABLE_D)
|
||||
|
|
|
@ -1,140 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PS_COMMON_H_
|
||||
#define MINDSPORE_CCSRC_PS_COMMON_H_
|
||||
|
||||
#include <limits.h>
|
||||
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include <string>
|
||||
|
||||
#include "ps/ps.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
constexpr char kEnvCommType[] = "MS_COMM_TYPE";
|
||||
constexpr char kEnvInterface[] = "MS_INTERFACE";
|
||||
constexpr char kEnvPServerNum[] = "MS_SERVER_NUM";
|
||||
constexpr char kEnvWorkerNum[] = "MS_WORKER_NUM";
|
||||
constexpr char kEnvSchedulerHost[] = "MS_SCHED_HOST";
|
||||
constexpr char kEnvSchedulerPort[] = "MS_SCHED_PORT";
|
||||
|
||||
constexpr char kDmlcCommType[] = "DMLC_PS_VAN_TYPE";
|
||||
constexpr char kDmlcInterface[] = "DMLC_INTERFACE";
|
||||
constexpr char kDmlcPServerNum[] = "DMLC_NUM_SERVER";
|
||||
constexpr char kDmlcWorkerNum[] = "DMLC_NUM_WORKER";
|
||||
constexpr char kDmlcRole[] = "DMLC_ROLE";
|
||||
constexpr char kDmlcSchedulerHost[] = "DMLC_PS_ROOT_URI";
|
||||
constexpr char kDmlcSchedulerPort[] = "DMLC_PS_ROOT_PORT";
|
||||
|
||||
constexpr char kCommTypeOfIBVerbs[] = "ibverbs";
|
||||
constexpr char kCommTypeOfTCP[] = "zmq";
|
||||
constexpr char kRoleOfPServer[] = "server";
|
||||
constexpr char kRoleOfWorker[] = "worker";
|
||||
constexpr char kRoleOfScheduler[] = "scheduler";
|
||||
|
||||
constexpr char kLearningRate[] = "learning_rate";
|
||||
constexpr char kMomentum[] = "momentum";
|
||||
|
||||
constexpr char kApplyMomentum[] = "ApplyMomentum";
|
||||
constexpr char kSparseAdam[] = "Adam";
|
||||
constexpr char kSparseLazyAdam[] = "LazyAdam";
|
||||
constexpr char kSparseFtrl[] = "Ftrl";
|
||||
constexpr char kApplyMomentumOp[] = "Momentum";
|
||||
constexpr char kSparseAdamOp[] = "Adam";
|
||||
constexpr char kSparseLazyAdamOp[] = "LazyAdam";
|
||||
constexpr char kSparseFtrlOp[] = "FTRL";
|
||||
|
||||
constexpr int64_t kInitWeightsCmd = 10;
|
||||
constexpr int64_t kInitWeightToOptimIdCmd = 11;
|
||||
constexpr int64_t kInitOptimInputsShapeCmd = 12;
|
||||
constexpr int64_t kInitKeyToPushNodeIdCmd = 13;
|
||||
constexpr int64_t kInitEmbeddingsCmd = 20;
|
||||
constexpr int64_t kUpdateEmbeddingsCmd = 21;
|
||||
constexpr int64_t kCheckReadyForPushCmd = 25;
|
||||
constexpr int64_t kCheckReadyForPullCmd = 26;
|
||||
constexpr int64_t kEmbeddingLookupCmd = 30;
|
||||
constexpr int64_t kFinalizeCmd = 40;
|
||||
|
||||
constexpr size_t kInvalidKey = UINT64_MAX;
|
||||
constexpr int64_t kInvalidID = -1;
|
||||
|
||||
using DataPtr = std::shared_ptr<unsigned char>;
|
||||
using VectorPtr = std::shared_ptr<std::vector<unsigned char>>;
|
||||
using Key = ::ps::Key;
|
||||
using Keys = ::ps::SArray<Key>;
|
||||
using Values = ::ps::SArray<float>;
|
||||
using ValuesPtr = std::shared_ptr<Values>;
|
||||
using Weight = ::ps::SArray<float>;
|
||||
using Grad = ::ps::SArray<float>;
|
||||
using LookupIds = ::ps::SArray<Key>;
|
||||
using Lengths = ::ps::SArray<int>;
|
||||
using WeightPtr = std::shared_ptr<Weight>;
|
||||
using GradPtr = std::shared_ptr<Grad>;
|
||||
using InputsShape = std::vector<std::shared_ptr<std::vector<size_t>>>;
|
||||
using InputsShapePtr = std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>>;
|
||||
|
||||
constexpr size_t INDEX_NOT_SEND = UINT_MAX;
|
||||
using OptimOriginIdx = std::map<std::string, size_t>;
|
||||
using OptimPSSendIdx = std::map<std::string, size_t>;
|
||||
|
||||
const OptimOriginIdx kMomentumOriginIdx = {{"weight", 0}, {"accum", 1}, {"lr", 2}, {"grad", 3}, {"momentum", 4}};
|
||||
const OptimPSSendIdx kMomentumPSSendIdx = {
|
||||
{"weight", INDEX_NOT_SEND}, {"accum", INDEX_NOT_SEND}, {"lr", 0}, {"grad", 1}, {"momentum", 2}};
|
||||
|
||||
const OptimOriginIdx kSparseAdamOriginIdx = {{"weight", 0}, {"m", 1}, {"v", 2}, {"beta1_power", 3},
|
||||
{"beta2_power", 4}, {"lr", 5}, {"beta1", 6}, {"beta2", 7},
|
||||
{"eps", 8}, {"grad", 9}, {"indices", 10}};
|
||||
const OptimPSSendIdx kSparseAdamPSSendIdx = {{"weight", INDEX_NOT_SEND},
|
||||
{"m", INDEX_NOT_SEND},
|
||||
{"v", INDEX_NOT_SEND},
|
||||
{"beta1_power", 0},
|
||||
{"beta2_power", 1},
|
||||
{"lr", 2},
|
||||
{"beta1", 3},
|
||||
{"beta2", 4},
|
||||
{"eps", 5},
|
||||
{"grad", 6},
|
||||
{"indices", 7}};
|
||||
|
||||
const OptimOriginIdx kSparseFtrlOriginIdx = {{"weight", 0}, {"accum", 1}, {"linear", 2}, {"grad", 3}, {"indices", 4}};
|
||||
const OptimPSSendIdx kSparseFtrlPSSendIdx = {
|
||||
{"weight", INDEX_NOT_SEND}, {"accum", INDEX_NOT_SEND}, {"linear", INDEX_NOT_SEND}, {"grad", 0}, {"indices", 1}};
|
||||
|
||||
const std::map<std::string, OptimOriginIdx> kOptimToOriginIdx = {{kApplyMomentum, kMomentumOriginIdx},
|
||||
{kSparseAdam, kSparseAdamOriginIdx},
|
||||
{kSparseLazyAdam, kSparseAdamOriginIdx},
|
||||
{kSparseFtrl, kSparseFtrlOriginIdx}};
|
||||
const std::map<std::string, OptimOriginIdx> kOptimToPSSendIdx = {{kApplyMomentum, kMomentumPSSendIdx},
|
||||
{kSparseAdam, kSparseAdamPSSendIdx},
|
||||
{kSparseLazyAdam, kSparseAdamPSSendIdx},
|
||||
{kSparseFtrl, kSparseFtrlPSSendIdx}};
|
||||
|
||||
#define EXC_IF_VEC_IDX_OOB(vec, idx) \
|
||||
{ \
|
||||
size_t vec_size = vec.size(); \
|
||||
if (idx >= vec_size) { \
|
||||
MS_LOG(EXCEPTION) << "Vector " << #vec << " size is " << vec_size << ". So index " << idx \
|
||||
<< " is out of bound."; \
|
||||
} \
|
||||
}
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PS_COMMON_H_
|
|
@ -14,10 +14,11 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PS_INTERNAL_CONSTANTS_H_
|
||||
#define MINDSPORE_CCSRC_PS_INTERNAL_CONSTANTS_H_
|
||||
#ifndef MINDSPORE_CCSRC_PS_CONSTANTS_H_
|
||||
#define MINDSPORE_CCSRC_PS_CONSTANTS_H_
|
||||
|
||||
#include <limits.h>
|
||||
|
||||
#include <climits>
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
@ -26,8 +27,6 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace internal {
|
||||
|
||||
constexpr char kEnvCommType[] = "MS_COMM_TYPE";
|
||||
constexpr char kEnvInterface[] = "MS_INTERFACE";
|
||||
constexpr char kEnvPServerNum[] = "MS_SERVER_NUM";
|
||||
|
@ -127,7 +126,6 @@ const std::map<std::string, OptimOriginIdx> kOptimToPSSendIdx = {{kApplyMomentum
|
|||
<< " is out of bound."; \
|
||||
} \
|
||||
}
|
||||
} // namespace internal
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PS_INTERNAL_CONSTANTS_H_
|
||||
#endif // MINDSPORE_CCSRC_PS_CONSTANTS_H_
|
|
@ -39,9 +39,9 @@ void ClusterMetadata::Init(const uint32_t &worker_num, const uint32_t &server_nu
|
|||
scheduler_port_ = scheduler_port;
|
||||
}
|
||||
|
||||
uint32_t ClusterMetadata::worker_num() { return worker_num_; }
|
||||
uint32_t ClusterMetadata::total_worker_num() { return worker_num_; }
|
||||
|
||||
uint32_t ClusterMetadata::server_num() { return server_num_; }
|
||||
uint32_t ClusterMetadata::total_server_num() { return server_num_; }
|
||||
|
||||
uint32_t ClusterMetadata::heartbeat_interval() { return heartbeat_interval_; }
|
||||
|
||||
|
|
|
@ -37,8 +37,8 @@ class ClusterMetadata {
|
|||
|
||||
void Init(const uint32_t &worker_num, const uint32_t &server_num, std::string scheduler_host,
|
||||
const uint16_t &scheduler_port);
|
||||
uint32_t worker_num();
|
||||
uint32_t server_num();
|
||||
uint32_t total_worker_num();
|
||||
uint32_t total_server_num();
|
||||
uint32_t heartbeat_interval();
|
||||
void set_heartbeat_interval(const uint32_t &heartbeat_interval);
|
||||
std::string scheduler_host();
|
||||
|
|
|
@ -122,9 +122,9 @@ std::string CommUtil::NodeRoleToString(const NodeRole &role) {
|
|||
}
|
||||
}
|
||||
bool CommUtil::ValidateRankId(const enum NodeRole &node_role, const uint32_t &rank_id) {
|
||||
if (node_role == NodeRole::SERVER && (rank_id > ClusterMetadata::instance()->server_num() - 1)) {
|
||||
if (node_role == NodeRole::SERVER && (rank_id > ClusterMetadata::instance()->total_server_num() - 1)) {
|
||||
return false;
|
||||
} else if (node_role == NodeRole::WORKER && (rank_id > ClusterMetadata::instance()->worker_num() - 1)) {
|
||||
} else if (node_role == NodeRole::WORKER && (rank_id > ClusterMetadata::instance()->total_worker_num() - 1)) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
|
|
|
@ -20,7 +20,7 @@ namespace mindspore {
|
|||
namespace ps {
|
||||
namespace core {
|
||||
void NodeManager::InitNodeNum() {
|
||||
total_node_num_ = ClusterMetadata::instance()->server_num() + ClusterMetadata::instance()->worker_num();
|
||||
total_node_num_ = ClusterMetadata::instance()->total_server_num() + ClusterMetadata::instance()->total_worker_num();
|
||||
}
|
||||
|
||||
int NodeManager::NextRankId(const RegisterMessage ®ister_message) {
|
||||
|
|
|
@ -1,179 +0,0 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PS_INTERNAL_PARAMETER_SERVER_H_
|
||||
#define MINDSPORE_CCSRC_PS_INTERNAL_PARAMETER_SERVER_H_
|
||||
|
||||
#include <unistd.h>
|
||||
#include <unordered_map>
|
||||
#include <string>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <mutex>
|
||||
#include <condition_variable>
|
||||
#include <thread>
|
||||
#include <cmath>
|
||||
#include <random>
|
||||
#include <utility>
|
||||
#include <list>
|
||||
#include <map>
|
||||
#include <functional>
|
||||
#include "ir/func_graph.h"
|
||||
#include "backend/session/session_basic.h"
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "backend/session/session_factory.h"
|
||||
#include "ps/optimizer_info.h"
|
||||
#include "ps/optimizer_info_builder.h"
|
||||
#include "ps/ps_context.h"
|
||||
#include "runtime/device/cpu/kernel_select_cpu.h"
|
||||
#include "utils/ms_context.h"
|
||||
#include "backend/kernel_compiler/kernel.h"
|
||||
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
|
||||
#include "backend/kernel_compiler/cpu/ps/pserver_kernel.h"
|
||||
#include "backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.h"
|
||||
#include "backend/kernel_compiler/cpu/ps/sparse_apply_lazy_adam_ps_kernel.h"
|
||||
#include "backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.h"
|
||||
#include "backend/kernel_compiler/cpu/ps/apply_momentum_ps_kernel.h"
|
||||
#include "backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.h"
|
||||
#include "ps/ps_cache/ps_data/ps_data_prefetch.h"
|
||||
#include "ps/random_normal/random_normal.h"
|
||||
|
||||
#include "ps/internal/constants.h"
|
||||
#include "ps/util.h"
|
||||
#include "ps/embedding_table_shard_metadata.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "proto/comm.pb.h"
|
||||
#include "proto/ps.pb.h"
|
||||
#include "ps/core/server_node.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace internal {
|
||||
|
||||
class ParameterServer {
|
||||
public:
|
||||
static ParameterServer &GetInstance() {
|
||||
static ParameterServer instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
void Run(const FuncGraphPtr &func_graph);
|
||||
|
||||
private:
|
||||
ParameterServer()
|
||||
: pserver_num_(0),
|
||||
worker_num_(0),
|
||||
rank_id_(0),
|
||||
grad_accum_count_(0),
|
||||
handler_(nullptr),
|
||||
func_graph_(nullptr),
|
||||
sess_(nullptr),
|
||||
running_(true),
|
||||
thread_(nullptr) {}
|
||||
~ParameterServer() = default;
|
||||
ParameterServer(const ParameterServer &) = delete;
|
||||
ParameterServer &operator=(const ParameterServer &) = delete;
|
||||
|
||||
class ServerHandler {
|
||||
public:
|
||||
explicit ServerHandler(ParameterServer *ps) : ps_(ps) {}
|
||||
~ServerHandler() = default;
|
||||
void Init();
|
||||
void operator()(std::shared_ptr<core::TcpConnection> conn, std::shared_ptr<core::MessageMeta> meta, DataPtr data,
|
||||
size_t size);
|
||||
void HandlePushReq(DataPtr data, size_t size, VectorPtr res);
|
||||
void HandlePullReq(DataPtr data, size_t size, VectorPtr res);
|
||||
void HandleInitWeights(DataPtr data, size_t size, VectorPtr res);
|
||||
void HandleInitWeightToOptimId(DataPtr data, size_t size, VectorPtr res);
|
||||
void HandleInitInputsShape(DataPtr data, size_t size, VectorPtr res);
|
||||
void HandleInitEmbeddings(DataPtr data, size_t size, VectorPtr res);
|
||||
void HandleCheckReadyForPush(DataPtr data, size_t size, VectorPtr res);
|
||||
void HandleCheckReadyForPull(DataPtr data, size_t size, VectorPtr res);
|
||||
void HandleEmbeddingLookup(DataPtr data, size_t size, VectorPtr res);
|
||||
void HandleUpdateEmbeddings(DataPtr data, size_t size, VectorPtr res);
|
||||
void HandleFinalize(DataPtr data, size_t size, VectorPtr res);
|
||||
|
||||
private:
|
||||
ParameterServer *ps_;
|
||||
typedef void (ServerHandler::*RequestHandler)(DataPtr data, size_t size, VectorPtr res);
|
||||
std::unordered_map<int, RequestHandler> handlers_;
|
||||
std::unordered_map<Key, bool> init_weights_;
|
||||
std::unordered_map<Key, bool> init_weight_to_optim_;
|
||||
std::unordered_map<Key, bool> init_optim_info_;
|
||||
};
|
||||
|
||||
bool Init(const FuncGraphPtr &func_graph);
|
||||
void InitOptimInfoBuilders();
|
||||
void InitWeightKeyToOptims(const Key &key, const int64_t &optim_id);
|
||||
void InitOptimInputsShape(const Keys &keys, const Values &values, const Lengths &lengths);
|
||||
void InitWeight(const Key &key, const WeightPtr &weight);
|
||||
void InitGrad(const Key &key, const GradPtr &grad);
|
||||
void InitEmbeddingTable(const Key &key,
|
||||
const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes,
|
||||
const ParamInitInfo ¶m_init_info);
|
||||
bool HasWeight(const Key &key);
|
||||
void Finalize();
|
||||
void UpdateWeights();
|
||||
void AccumGrad(const Keys &key, const Values &values, const Lengths &lengths);
|
||||
WeightPtr weight(const Key &key);
|
||||
void DoEmbeddingLookup(Key key, const LookupIds &lookup_ids, KVMessage *res);
|
||||
void UpdateEmbeddings(const Key &key, const LookupIds &lookup_ids, const Values &vals);
|
||||
bool ReadyForUpdateWeights();
|
||||
bool ReadyForPush(const Key &key);
|
||||
bool ReadyForPull(const Key &key);
|
||||
void ResetGradAccumCount();
|
||||
const CNodePtr GetCNode(const std::string &name) const;
|
||||
std::mutex &mutex();
|
||||
void GetEmbeddingTableParamPtr();
|
||||
void SyncEmbeddingTables();
|
||||
|
||||
size_t pserver_num_;
|
||||
size_t worker_num_;
|
||||
size_t rank_id_;
|
||||
size_t grad_accum_count_;
|
||||
std::unique_ptr<ServerHandler> handler_;
|
||||
FuncGraphPtr func_graph_;
|
||||
std::shared_ptr<session::SessionBasic> sess_;
|
||||
bool running_;
|
||||
|
||||
std::unordered_map<Key, std::shared_ptr<PServerKernel>> optimizers_;
|
||||
std::unordered_map<Key, InputsShapePtr> optim_inputs_shape_;
|
||||
std::unordered_map<Key, InputsShapePtr> original_optim_inputs_shape_;
|
||||
std::unordered_map<Key, std::shared_ptr<OptimizerInfo>> optim_infos_;
|
||||
std::unordered_map<std::string, std::shared_ptr<OptimizerInfoBuilder>> optim_info_builders_;
|
||||
std::unordered_map<Key, std::string> weight_key_to_optims_;
|
||||
std::unordered_map<Key, std::string> weight_key_to_optim_op_;
|
||||
std::unordered_map<Key, WeightPtr> weights_;
|
||||
std::unordered_map<Key, bool> is_embedding_;
|
||||
std::unordered_map<Key, WeightPtr> grads_;
|
||||
std::unordered_map<Key, size_t> grads_accum_counter_;
|
||||
std::unordered_map<Key, std::shared_ptr<PServerKernel>> embedding_lookup_ops_;
|
||||
std::unordered_map<Key, uint64_t> tokens_;
|
||||
|
||||
std::mutex mutex_;
|
||||
std::condition_variable apply_grads_cv_;
|
||||
|
||||
std::unique_ptr<std::thread> thread_;
|
||||
core::ServerNode server_node_;
|
||||
std::map<Key, ParameterPtr> embedding_tables_;
|
||||
|
||||
friend class ServerHandler;
|
||||
};
|
||||
} // namespace internal
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PS_INTERNAL_PARAMETER_SERVER_H_
|
|
@ -1,157 +0,0 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PS_INTERNAL_WORKER_H_
|
||||
#define MINDSPORE_CCSRC_PS_INTERNAL_WORKER_H_
|
||||
|
||||
#include <utility>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <numeric>
|
||||
#include <functional>
|
||||
#include <algorithm>
|
||||
#include <map>
|
||||
#include <mutex>
|
||||
#include <unordered_set>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "utils/log_adapter.h"
|
||||
#include "ir/tensor.h"
|
||||
#include "ps/util.h"
|
||||
#include "ps/internal/constants.h"
|
||||
#include "utils/shape_utils.h"
|
||||
#include "ps/ps_cache/ps_data/ps_data_prefetch.h"
|
||||
#include "ps/core/worker_node.h"
|
||||
#include "ps/embedding_table_shard_metadata.h"
|
||||
#include "proto/comm.pb.h"
|
||||
#include "proto/ps.pb.h"
|
||||
#include "ps/ps_context.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace internal {
|
||||
|
||||
class Worker {
|
||||
public:
|
||||
static Worker &GetInstance() {
|
||||
static Worker instance;
|
||||
return instance;
|
||||
}
|
||||
using Callback = std::function<void()>;
|
||||
using PartitionEmbeddingMessages = std::vector<std::pair<bool, EmbeddingTableLookup>>;
|
||||
using PartitionKVMessages = std::vector<std::pair<bool, KVMessage>>;
|
||||
|
||||
using EmbeddingPartitioner = std::function<void(
|
||||
const EmbeddingTableLookup &send, PartitionEmbeddingMessages *partition, const std::map<int64_t, int64_t> &attrs)>;
|
||||
using KVPartitioner =
|
||||
std::function<void(const KVMessage &send, PartitionKVMessages *partition, const std::map<int64_t, int64_t> &attrs)>;
|
||||
|
||||
void Run();
|
||||
void Push(const std::vector<size_t> &keys, std::vector<uintptr_t> addrs, const ShapeVector &sizes);
|
||||
void Pull(const size_t key, void *dev_addr, const size_t size);
|
||||
size_t SetParamKey(const std::string ¶m_name);
|
||||
size_t GetParamKey(const std::string ¶m_name);
|
||||
void SetParamInitInServer(const std::string ¶m_name, bool init_in_server);
|
||||
bool GetParamInitInServer(const std::string ¶m_name);
|
||||
void SetKeyOptimId(size_t key, const std::string &optimizer_name);
|
||||
void SetOptimInputShapes(size_t key, const ShapeVector &shape);
|
||||
void AddEmbeddingTable(const Key &key, const size_t &row_count);
|
||||
void InitPSEmbeddingTable(const size_t &key, const std::vector<size_t> &input_shape,
|
||||
const std::vector<size_t> &indices_shape, const std::vector<size_t> &output_shape);
|
||||
void InitPSParamAndOptim(const AnfNodePtr &input_node, const tensor::TensorPtr &tensor);
|
||||
void DoPSEmbeddingLookup(const Key &key, const std::vector<int> &lookup_ids, std::vector<float> *lookup_result,
|
||||
int64_t cmd);
|
||||
void UpdateEmbeddingTable(const std::vector<Key> &keys, const std::vector<int> &lookup_ids,
|
||||
const std::vector<float> &vals);
|
||||
|
||||
bool running() { return running_; }
|
||||
void Finalize();
|
||||
|
||||
private:
|
||||
Worker() : running_(false), key_cnt_(0) {}
|
||||
~Worker() = default;
|
||||
Worker(const Worker &) = delete;
|
||||
Worker &operator=(const Worker &) = delete;
|
||||
|
||||
void Initialize();
|
||||
bool IsKeyInit(const size_t key);
|
||||
void AddKeyToServerId(const Key &key);
|
||||
void AddKeyByHashMod(const Key &key);
|
||||
void InitPSOptimId(const size_t param_key);
|
||||
void InitPSOptimInputShapes(const size_t key);
|
||||
void InitPSParamData(const std::vector<size_t> &keys, void *origin_addr, size_t size);
|
||||
bool IsReadyForPush(const Key &key);
|
||||
bool IsReadyForPull(const Key &key);
|
||||
void PrepareSparseGradient(const size_t begin, const size_t end, const std::unordered_set<int> &distinct_ids,
|
||||
const std::vector<std::pair<int, float *>> &indice_to_grads, const int *all_indice,
|
||||
const size_t segment_size, float *gradient, int *indices);
|
||||
void BuildSparseValue(const std::vector<int> &lengths, const size_t grad_index, const size_t indice_index,
|
||||
const float *original_data, const float *grads, int *indices, std::vector<float> *reduced_data);
|
||||
|
||||
void PushData(const std::vector<Key> &keys, const std::vector<float> &vals, const std::vector<int> &lens = {},
|
||||
int command = 0, int64_t priority = 0);
|
||||
void PushSparseData(const std::vector<Key> &keys, const std::vector<float> &vals, const std::vector<int> &lens,
|
||||
size_t grad_index, size_t indice_index, size_t first_dim_size, size_t outer_dim_size);
|
||||
void PullData(const std::vector<Key> &keys, std::vector<float> *vals, std::vector<int> *lens = nullptr, int cmd = 0,
|
||||
int64_t priority = 0);
|
||||
|
||||
void LookupIdPartitioner(const EmbeddingTableLookup &send, PartitionEmbeddingMessages *partition,
|
||||
const std::map<int64_t, int64_t> &attrs);
|
||||
|
||||
void SparsePartitioner(const KVMessage &send, PartitionKVMessages *partition,
|
||||
const std::map<int64_t, int64_t> &attrs);
|
||||
void RoundRobinPartitioner(const KVMessage &send, PartitionKVMessages *partition,
|
||||
const std::map<int64_t, int64_t> &attrs);
|
||||
void WorkerInitEmbeddingPartitioner(const KVMessage &send, std::vector<std::pair<bool, KVMessage>> *partition,
|
||||
const std::map<int64_t, int64_t> &attrs);
|
||||
void UpdateEmbeddingPartitioner(const KVMessage &send, PartitionKVMessages *partition,
|
||||
const std::map<int64_t, int64_t> &attrs);
|
||||
void BroadcastPartitioner(const KVMessage &send, PartitionKVMessages *partition,
|
||||
const std::map<int64_t, int64_t> &attrs);
|
||||
void SendForPush(int cmd, const KVMessage &send, const KVPartitioner &partitioner,
|
||||
const std::map<int64_t, int64_t> &attrs);
|
||||
void SendForPull(int cmd, const KVMessage &send, const KVPartitioner &partitioner,
|
||||
const std::map<int64_t, int64_t> &attrs, std::vector<float> *vals, std::vector<int> *lens);
|
||||
|
||||
int64_t server_num_;
|
||||
bool running_;
|
||||
std::mutex running_mutex_;
|
||||
size_t key_cnt_;
|
||||
std::map<std::string, size_t> param_to_key_;
|
||||
std::map<size_t, bool> init_keys_;
|
||||
std::map<size_t, int64_t> key_to_optimId_;
|
||||
std::map<size_t, std::vector<ShapeVector>> key_to_optim_shapes_;
|
||||
std::map<std::string, bool> param_to_init_in_server_;
|
||||
core::WorkerNode worker_node_;
|
||||
|
||||
EmbeddingPartitioner lookup_partitioner_;
|
||||
KVPartitioner sparse_partitioner_;
|
||||
KVPartitioner round_robin_partitioner_;
|
||||
KVPartitioner worker_init_embedding_partitioner_;
|
||||
KVPartitioner update_embedding_partitioner_;
|
||||
KVPartitioner broadcast_partitioner_;
|
||||
std::unordered_map<Key, int64_t> key_to_server_id_;
|
||||
std::unordered_map<Key, size_t> embedding_row_cnt_;
|
||||
|
||||
std::unordered_map<Key, std::shared_ptr<std::vector<EmbeddingTableShardMetadata>>> embedding_table_ranges_;
|
||||
};
|
||||
|
||||
static Worker &worker = Worker::GetInstance();
|
||||
} // namespace internal
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PS_INTERNAL_WORKER_H_
|
|
@ -84,7 +84,7 @@ void DenseOptimInfo::Accumulate(const Values &values, const Lengths &lengths) {
|
|||
for (size_t i = 0; i < grad_index; i++) {
|
||||
grad_offset += lengths[i];
|
||||
}
|
||||
float *grad_data = values.data() + grad_offset;
|
||||
float *grad_data = const_cast<float *>(values.data()) + grad_offset;
|
||||
CHECK_EQ(size, static_cast<size_t>(lengths[grad_index]));
|
||||
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
|
@ -121,7 +121,7 @@ void SparseOptimInfo::Accumulate(const Values &values, const Lengths &lengths) {
|
|||
for (size_t i = 0; i < grad_index; i++) {
|
||||
grad_offset += lengths[i];
|
||||
}
|
||||
float *incr_grad_data = values.data() + grad_offset;
|
||||
float *incr_grad_data = const_cast<float *>(values.data()) + grad_offset;
|
||||
MS_EXCEPTION_IF_NULL(incr_grad_data);
|
||||
|
||||
size_t incr_grad_size = lengths[grad_index] * sizeof(float);
|
||||
|
@ -148,7 +148,11 @@ void SparseOptimInfo::Accumulate(const Values &values, const Lengths &lengths) {
|
|||
for (size_t i = 0; i < indices_index; i++) {
|
||||
indice_offset += lengths[i];
|
||||
}
|
||||
int *incr_indice_data = reinterpret_cast<int *>(values.data()) + indice_offset;
|
||||
|
||||
void *incr_indice_data_temp = const_cast<float *>(values.data()) + indice_offset;
|
||||
|
||||
int *incr_indice_data = reinterpret_cast<int *>(incr_indice_data_temp);
|
||||
|
||||
MS_EXCEPTION_IF_NULL(incr_indice_data);
|
||||
size_t incr_indice_size = lengths[indices_index];
|
||||
size_t incr_indice_data_size = incr_indice_size * sizeof(int);
|
||||
|
@ -259,7 +263,7 @@ MomentumOptimInfo::MomentumOptimInfo(const AddressPtr &weight, const AddressPtr
|
|||
}
|
||||
|
||||
void MomentumOptimInfo::Update(const Values &values, const Lengths &lens) {
|
||||
UpdateOptimInputValue<float>(kApplyMomentum, "lr", values.data(), lens);
|
||||
UpdateOptimInputValue<float>(kApplyMomentum, "lr", const_cast<float *>(values.data()), lens);
|
||||
}
|
||||
|
||||
const size_t SparseOptimInfo::indice_size() const { return indices_offset_; }
|
||||
|
@ -303,12 +307,12 @@ SparseAdamOptimInfo::SparseAdamOptimInfo(const AddressPtr &weight, const Address
|
|||
}
|
||||
|
||||
void SparseAdamOptimInfo::Update(const Values &values, const Lengths &lens) {
|
||||
UpdateOptimInputValue<float>(kSparseAdam, "beta1_power", values.data(), lens);
|
||||
UpdateOptimInputValue<float>(kSparseAdam, "beta2_power", values.data(), lens);
|
||||
UpdateOptimInputValue<float>(kSparseAdam, "lr", values.data(), lens);
|
||||
UpdateOptimInputValue<float>(kSparseAdam, "beta1", values.data(), lens);
|
||||
UpdateOptimInputValue<float>(kSparseAdam, "beta2", values.data(), lens);
|
||||
UpdateOptimInputValue<float>(kSparseAdam, "eps", values.data(), lens);
|
||||
UpdateOptimInputValue<float>(kSparseAdam, "beta1_power", const_cast<float *>(values.data()), lens);
|
||||
UpdateOptimInputValue<float>(kSparseAdam, "beta2_power", const_cast<float *>(values.data()), lens);
|
||||
UpdateOptimInputValue<float>(kSparseAdam, "lr", const_cast<float *>(values.data()), lens);
|
||||
UpdateOptimInputValue<float>(kSparseAdam, "beta1", const_cast<float *>(values.data()), lens);
|
||||
UpdateOptimInputValue<float>(kSparseAdam, "beta2", const_cast<float *>(values.data()), lens);
|
||||
UpdateOptimInputValue<float>(kSparseAdam, "eps", const_cast<float *>(values.data()), lens);
|
||||
}
|
||||
|
||||
const AddressPtr &SparseAdamOptimInfo::gradient() {
|
||||
|
|
|
@ -20,7 +20,7 @@
|
|||
#include <vector>
|
||||
#include <string>
|
||||
#include "backend/kernel_compiler/kernel.h"
|
||||
#include "ps/common.h"
|
||||
#include "ps/constants.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
|
|
|
@ -129,9 +129,9 @@ OptimizerInfo *MomentumOptimInfoBuilder::BuildInputs(const WeightPtr &weight, co
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
AddressPtr learning_rate = GenInputAddrPtr<float>(kApplyMomentum, "lr", values.data(), lens);
|
||||
AddressPtr gradient = GenInputAddrPtr<float>(kApplyMomentum, "grad", values.data(), lens);
|
||||
AddressPtr momentum = GenInputAddrPtr<float>(kApplyMomentum, "momentum", values.data(), lens);
|
||||
AddressPtr learning_rate = GenInputAddrPtr<float>(kApplyMomentum, "lr", const_cast<float *>(values.data()), lens);
|
||||
AddressPtr gradient = GenInputAddrPtr<float>(kApplyMomentum, "grad", const_cast<float *>(values.data()), lens);
|
||||
AddressPtr momentum = GenInputAddrPtr<float>(kApplyMomentum, "momentum", const_cast<float *>(values.data()), lens);
|
||||
return new MomentumOptimInfo(weight_addr, accumulate, learning_rate, gradient, momentum);
|
||||
}
|
||||
|
||||
|
@ -172,14 +172,15 @@ OptimizerInfo *SparseAdamOptimInfoBuilder::BuildInputs(const WeightPtr &weight,
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
AddressPtr beta1_power = GenInputAddrPtr<float>(kSparseAdam, "beta1_power", values.data(), lens);
|
||||
AddressPtr beta2_power = GenInputAddrPtr<float>(kSparseAdam, "beta2_power", values.data(), lens);
|
||||
AddressPtr learning_rate = GenInputAddrPtr<float>(kSparseAdam, "lr", values.data(), lens);
|
||||
AddressPtr beta1 = GenInputAddrPtr<float>(kSparseAdam, "beta1", values.data(), lens);
|
||||
AddressPtr beta2 = GenInputAddrPtr<float>(kSparseAdam, "beta2", values.data(), lens);
|
||||
AddressPtr epsilon = GenInputAddrPtr<float>(kSparseAdam, "eps", values.data(), lens);
|
||||
AddressPtr grad = GenInputAddrPtr<float>(kSparseAdam, "grad", values.data(), lens, inputs_shape);
|
||||
AddressPtr indices = GenInputAddrPtr<float>(kSparseAdam, "indices", values.data(), lens, inputs_shape);
|
||||
AddressPtr beta1_power = GenInputAddrPtr<float>(kSparseAdam, "beta1_power", const_cast<float *>(values.data()), lens);
|
||||
AddressPtr beta2_power = GenInputAddrPtr<float>(kSparseAdam, "beta2_power", const_cast<float *>(values.data()), lens);
|
||||
AddressPtr learning_rate = GenInputAddrPtr<float>(kSparseAdam, "lr", const_cast<float *>(values.data()), lens);
|
||||
AddressPtr beta1 = GenInputAddrPtr<float>(kSparseAdam, "beta1", const_cast<float *>(values.data()), lens);
|
||||
AddressPtr beta2 = GenInputAddrPtr<float>(kSparseAdam, "beta2", const_cast<float *>(values.data()), lens);
|
||||
AddressPtr epsilon = GenInputAddrPtr<float>(kSparseAdam, "eps", const_cast<float *>(values.data()), lens);
|
||||
AddressPtr grad = GenInputAddrPtr<float>(kSparseAdam, "grad", const_cast<float *>(values.data()), lens, inputs_shape);
|
||||
AddressPtr indices =
|
||||
GenInputAddrPtr<float>(kSparseAdam, "indices", const_cast<float *>(values.data()), lens, inputs_shape);
|
||||
return new SparseAdamOptimInfo(weight_addr, m, v, beta1_power, beta2_power, learning_rate, beta1, beta2, epsilon,
|
||||
grad, indices, sharded);
|
||||
}
|
||||
|
@ -218,8 +219,9 @@ OptimizerInfo *SparseFtrlOptimInfoBuilder::BuildInputs(const WeightPtr &weight,
|
|||
}
|
||||
linear->size = weight->size() * sizeof(float);
|
||||
|
||||
AddressPtr grad = GenInputAddrPtr<float>(kSparseFtrl, "grad", values.data(), lens, inputs_shape);
|
||||
AddressPtr indices = GenInputAddrPtr<float>(kSparseFtrl, "indices", values.data(), lens, inputs_shape);
|
||||
AddressPtr grad = GenInputAddrPtr<float>(kSparseFtrl, "grad", const_cast<float *>(values.data()), lens, inputs_shape);
|
||||
AddressPtr indices =
|
||||
GenInputAddrPtr<float>(kSparseFtrl, "indices", const_cast<float *>(values.data()), lens, inputs_shape);
|
||||
return new SparseFtrlOptimInfo(weight_addr, accum, linear, grad, indices, sharded);
|
||||
}
|
||||
} // namespace ps
|
||||
|
|
|
@ -14,12 +14,10 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ps/internal/parameter_server.h"
|
||||
#include "ps/parameter_server.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace internal {
|
||||
|
||||
void ParameterServer::Run(const FuncGraphPtr &func_graph) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_LOG(INFO) << "PServer starts connecting to scheduler and workers...";
|
||||
|
@ -44,8 +42,8 @@ void ParameterServer::Run(const FuncGraphPtr &func_graph) {
|
|||
}
|
||||
|
||||
bool ParameterServer::Init(const FuncGraphPtr &func_graph) {
|
||||
pserver_num_ = std::strtol(mindspore::common::GetEnv(kEnvWorkerNum).c_str(), nullptr, 10);
|
||||
worker_num_ = std::strtol(mindspore::common::GetEnv(kEnvPServerNum).c_str(), nullptr, 10);
|
||||
pserver_num_ = std::strtol(mindspore::common::GetEnv(kEnvPServerNum).c_str(), nullptr, 10);
|
||||
worker_num_ = std::strtol(mindspore::common::GetEnv(kEnvWorkerNum).c_str(), nullptr, 10);
|
||||
func_graph_ = func_graph;
|
||||
handler_.reset(new ServerHandler(this));
|
||||
handler_->Init();
|
||||
|
@ -257,12 +255,21 @@ void ParameterServer::AccumGrad(const Keys &keys, const Values &values, const Le
|
|||
std::shared_ptr<OptimizerInfo> optim_info = optim_infos_[key];
|
||||
|
||||
// Create or update the optimizer info
|
||||
std::shared_ptr<kernel::ps::PServerKernel> pserver_kernel = optimizers_[key];
|
||||
if (pserver_kernel == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "no optimizer found for key " << key << " optim name " << weight_key_to_optims_[key];
|
||||
if (optim_info == nullptr) {
|
||||
const std::shared_ptr<OptimizerInfoBuilder> &builder = optim_info_builders_[weight_key_to_optims_[key]];
|
||||
std::shared_ptr<kernel::ps::PServerKernel> pserver_kernel = optimizers_[key];
|
||||
if (pserver_kernel == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "no optimizer found for key " << key << " optim name " << weight_key_to_optims_[key];
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(pserver_kernel);
|
||||
OptimizerInfo *optim = builder->Build(pserver_kernel, weights_[key], keys, values, lengths,
|
||||
optim_inputs_shape_[key], worker_num_, is_embedding_[key]);
|
||||
optim_info.reset(optim);
|
||||
optim_infos_[key] = optim_info;
|
||||
} else {
|
||||
optim_info->Update(values, lengths);
|
||||
optim_info->Accumulate(values, lengths);
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(pserver_kernel);
|
||||
optim_infos_[key] = optim_info;
|
||||
}
|
||||
|
||||
grads_accum_counter_[key] += 1;
|
||||
|
@ -373,7 +380,7 @@ inline bool ParameterServer::ReadyForPush(const Key &key) {
|
|||
MS_LOG(EXCEPTION) << "The weights in server is empty. Many reasons could cause this: 1.The Worker didn't send "
|
||||
"kInitWeightsCmd command. 2.The Server failed to initialize weights.";
|
||||
}
|
||||
MS_LOG(INFO) << "the grad_accum_count_:" << grad_accum_count_ << " the weights_:" << weights_.size()
|
||||
MS_LOG(INFO) << "The grad_accum_count_:" << grad_accum_count_ << " the weights_:" << weights_.size()
|
||||
<< " the token:" << (tokens_[key] <= 0);
|
||||
return grad_accum_count_ < weights_.size() && tokens_[key] <= 0;
|
||||
}
|
||||
|
@ -544,11 +551,9 @@ void ParameterServer::ServerHandler::HandleInitWeights(DataPtr data, size_t size
|
|||
for (int i = 0; i < key_num; i++) {
|
||||
Key key = input.keys()[i];
|
||||
size_t data_len = input.len_size() != key_num ? input.values_size() / key_num : input.len()[i];
|
||||
MS_LOG(DEBUG) << "The data len:" << data_len;
|
||||
|
||||
if (!ps_->HasWeight(key)) {
|
||||
WeightPtr weight_ptr = std::make_shared<std::vector<float>>(data_ptr + pos, data_ptr + (pos + data_len));
|
||||
MS_LOG(DEBUG) << "The weight ptr:" << *weight_ptr;
|
||||
MS_EXCEPTION_IF_NULL(weight_ptr);
|
||||
ps_->InitWeight(key, weight_ptr);
|
||||
|
||||
|
@ -637,7 +642,7 @@ void ParameterServer::ServerHandler::HandleCheckReadyForPush(DataPtr data, size_
|
|||
input.ParseFromArray(data.get(), size);
|
||||
const Key &key = input.keys()[0];
|
||||
bool ready = ps_->ReadyForPush(key);
|
||||
MS_LOG(INFO) << "the ready is:" << ready;
|
||||
MS_LOG(INFO) << "The ready is:" << ready;
|
||||
KVMessage res_data;
|
||||
res_data.add_keys(key);
|
||||
res_data.add_values(ready);
|
||||
|
@ -671,7 +676,6 @@ void ParameterServer::ServerHandler::HandleEmbeddingLookup(DataPtr data, size_t
|
|||
EmbeddingTableLookup input;
|
||||
input.ParseFromArray(data.get(), size);
|
||||
const Key &key = input.key();
|
||||
MS_LOG(DEBUG) << "The key is:" << key;
|
||||
|
||||
KVMessage res_data;
|
||||
std::vector<Key> keys = {input.keys().begin(), input.keys().end()};
|
||||
|
@ -701,6 +705,5 @@ void ParameterServer::ServerHandler::HandleFinalize(DataPtr data, size_t size, V
|
|||
MS_EXCEPTION_IF_NULL(res);
|
||||
ps_->Finalize();
|
||||
}
|
||||
} // namespace internal
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
File diff suppressed because it is too large
Load Diff
|
@ -145,7 +145,6 @@ const size_t &PsCacheManager::QueryHashTableSize(const std::string ¶m_name)
|
|||
void PsCacheManager::Initialize() {
|
||||
MS_LOG(INFO) << "PS cache initialize.";
|
||||
if (!worker.running()) {
|
||||
Util::SetInternalEnvVar();
|
||||
worker.Run();
|
||||
}
|
||||
embedding_device_cache_ = std::make_shared<EmbeddingDeviceCache>(batch_elements_, vocab_cache_size_);
|
||||
|
@ -177,22 +176,19 @@ void PsCacheManager::InitParameterServer() {
|
|||
for (const auto &item : hash_tables_) {
|
||||
const auto ¶m_name = item.first;
|
||||
size_t key = worker.SetParamKey(param_name);
|
||||
std::vector<size_t> keys{key, key, key, key, key, key};
|
||||
std::vector<float> values{
|
||||
SizeToFloat(item.second.vocab_size), SizeToFloat(item.second.embedding_size), 1, 1, 1, 1, 1};
|
||||
std::vector<int64_t> lens{2, 2, 3};
|
||||
const auto &hash_table_info = item.second;
|
||||
const auto ¶m_init_info = hash_table_info.param_init_info_;
|
||||
if (param_init_info.param_type_ == kWeight) {
|
||||
lens.push_back(1);
|
||||
} else if (param_init_info.param_type_ == kAccumulation) {
|
||||
lens.push_back(2);
|
||||
}
|
||||
values.push_back(param_init_info.init_val_);
|
||||
lens.push_back(param_init_info.global_seed_);
|
||||
lens.push_back(param_init_info.op_seed_);
|
||||
|
||||
std::vector<size_t> input_shape = {item.second.vocab_size, item.second.embedding_size};
|
||||
std::vector<size_t> indices_shape = {1, 1};
|
||||
std::vector<size_t> output_shape = {1, 1, 1};
|
||||
ParamInitInfoMessage info;
|
||||
info.set_param_type(param_init_info.param_type_);
|
||||
info.set_init_val(param_init_info.init_val_);
|
||||
info.set_global_seed(param_init_info.global_seed_);
|
||||
info.set_op_seed(param_init_info.op_seed_);
|
||||
// if worker role
|
||||
worker.InitPSEmbeddingTable(keys, values, lens);
|
||||
worker.InitPSEmbeddingTable(key, input_shape, indices_shape, output_shape, info);
|
||||
}
|
||||
|
||||
finish_init_parameter_server_ = true;
|
||||
|
@ -245,7 +241,7 @@ void PsCacheManager::AllocMemForHashTable() {
|
|||
}
|
||||
|
||||
void PsCacheManager::SetLocalIdRank() {
|
||||
auto worker_num = ::ps::NumWorkers();
|
||||
auto worker_num = PSContext::instance()->initial_worker_num();
|
||||
auto local_shard_size = FloatToInt(std::ceil(SizeToFloat(vocab_size_) / worker_num));
|
||||
vocab_cache_size_diff_ = local_shard_size - SizeToInt(vocab_cache_size_);
|
||||
emb_table_slice_bounds_.first = local_shard_size * rank_id_;
|
||||
|
@ -829,8 +825,8 @@ bool PsCacheManager::HashSwapHostToServer(size_t key, const HashTableInfo &hash_
|
|||
if (swap_indices_size == 0) {
|
||||
return true;
|
||||
}
|
||||
::ps::SArray<int> lookup_ids(swap_indices_size, 0);
|
||||
::ps::SArray<float> swap_out_data;
|
||||
std::vector<int> lookup_ids(swap_indices_size, 0);
|
||||
std::vector<float> swap_out_data;
|
||||
auto embedding_size = hash_info.embedding_size;
|
||||
swap_out_data.resize(swap_indices_size * embedding_size);
|
||||
auto host_hash_table_addr = reinterpret_cast<float *>(hash_info.host_address.get());
|
||||
|
@ -857,22 +853,21 @@ bool PsCacheManager::HashSwapServerToHost(size_t key, const HashTableInfo &hash_
|
|||
}
|
||||
auto host_hash_table_addr = reinterpret_cast<float *>(hash_info.host_address.get());
|
||||
auto embedding_size = hash_info.embedding_size;
|
||||
::ps::SArray<int> lengths{swap_indices_size};
|
||||
::ps::SArray<float> lookup_result(swap_indices_size * embedding_size, 0);
|
||||
::ps::SArray<int> lookup_ids(swap_indices_size, 0);
|
||||
std::vector<float> lookup_result(swap_indices_size * embedding_size, 0);
|
||||
std::vector<int> lookup_ids(swap_indices_size, 0);
|
||||
auto copy_len = swap_indices_size * sizeof(int);
|
||||
auto ret = memcpy_s(lookup_ids.data(), copy_len, server_to_host_ids, copy_len);
|
||||
if (ret != EOK) {
|
||||
MS_LOG(ERROR) << "Lookup id memcpy failed.";
|
||||
return false;
|
||||
}
|
||||
worker.DoPSEmbeddingLookup({key}, lookup_ids, lengths, &lookup_result, mindspore::ps::kEmbeddingLookupCmd);
|
||||
worker.DoPSEmbeddingLookup(key, lookup_ids, &lookup_result, mindspore::ps::kEmbeddingLookupCmd);
|
||||
RETURN_IF_FALSE(InsertHostHashTable(embedding_size, IntToSize(swap_indices_size), server_to_host_index,
|
||||
lookup_result.data(), host_hash_table_addr));
|
||||
return true;
|
||||
}
|
||||
|
||||
bool PsCacheManager::HashSwapDeviceOut(int *swap_out_index, ::ps::SArray<float> *swap_out_data,
|
||||
bool PsCacheManager::HashSwapDeviceOut(int *swap_out_index, std::vector<float> *swap_out_data,
|
||||
const HashTableInfo &hash_info) {
|
||||
MS_ERROR_IF_NULL(swap_out_index);
|
||||
MS_ERROR_IF_NULL(swap_out_data);
|
||||
|
@ -912,16 +907,15 @@ bool PsCacheManager::HashSwapDeviceIn(int *swap_in_ids, int *swap_in_index, cons
|
|||
auto cache_vocab_size = hash_info.cache_vocab_size;
|
||||
auto embedding_size = hash_info.embedding_size;
|
||||
// Get id embs by swap_in_ids in host(Pipeline with hash swap-out in device).
|
||||
::ps::SArray<int> lengths{swap_in_ids_size};
|
||||
::ps::SArray<float> lookup_result(swap_in_ids_size * embedding_size, 0);
|
||||
::ps::SArray<int> lookup_ids(swap_in_ids_size, 0);
|
||||
std::vector<float> lookup_result(swap_in_ids_size * embedding_size, 0);
|
||||
std::vector<int> lookup_ids(swap_in_ids_size, 0);
|
||||
auto copy_len = swap_in_ids_size * sizeof(int);
|
||||
auto ret = memcpy_s(lookup_ids.data(), copy_len, swap_in_ids, copy_len);
|
||||
if (ret != EOK) {
|
||||
MS_LOG(ERROR) << "Lookup id memcpy failed.";
|
||||
return false;
|
||||
}
|
||||
worker.DoPSEmbeddingLookup({key}, lookup_ids, lengths, &lookup_result, mindspore::ps::kEmbeddingLookupCmd);
|
||||
worker.DoPSEmbeddingLookup(key, lookup_ids, &lookup_result, mindspore::ps::kEmbeddingLookupCmd);
|
||||
// Hash swap-in in device.
|
||||
RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyHostMemToDevice(
|
||||
embedding_device_cache_->hash_swap_value_addr_, lookup_result.data(),
|
||||
|
@ -934,7 +928,7 @@ bool PsCacheManager::HashSwapDeviceIn(int *swap_in_ids, int *swap_in_index, cons
|
|||
return true;
|
||||
}
|
||||
|
||||
bool PsCacheManager::UpdataEmbeddingTable(const ::ps::SArray<float> &swap_out_data, int *swap_out_ids, size_t key) {
|
||||
bool PsCacheManager::UpdataEmbeddingTable(const std::vector<float> &swap_out_data, int *swap_out_ids, size_t key) {
|
||||
MS_ERROR_IF_NULL(embedding_device_cache_);
|
||||
MS_ERROR_IF_NULL(embedding_device_cache_->cache_);
|
||||
MS_ERROR_IF_NULL(swap_out_ids);
|
||||
|
@ -942,7 +936,7 @@ bool PsCacheManager::UpdataEmbeddingTable(const ::ps::SArray<float> &swap_out_da
|
|||
if (swap_out_ids_size == 0) {
|
||||
return true;
|
||||
}
|
||||
::ps::SArray<int> lookup_ids(swap_out_ids_size, 0);
|
||||
std::vector<int> lookup_ids(swap_out_ids_size, 0);
|
||||
auto copy_len = swap_out_ids_size * sizeof(int);
|
||||
auto ret = memcpy_s(lookup_ids.data(), copy_len, swap_out_ids, copy_len);
|
||||
if (ret != EOK) {
|
||||
|
@ -994,8 +988,8 @@ bool PsCacheManager::SyncHostEmbeddingTable() {
|
|||
continue;
|
||||
}
|
||||
auto key = worker.GetParamKey(item.first);
|
||||
::ps::SArray<int> lookup_ids(swap_indices_lens, 0);
|
||||
::ps::SArray<float> swap_out_data;
|
||||
std::vector<int> lookup_ids(swap_indices_lens, 0);
|
||||
std::vector<float> swap_out_data;
|
||||
auto embedding_size = hash_info.embedding_size;
|
||||
swap_out_data.resize(swap_indices_lens * embedding_size);
|
||||
auto host_hash_table_addr = hash_info.host_address.get();
|
||||
|
@ -1038,8 +1032,8 @@ bool PsCacheManager::SyncDeviceEmbeddingTable() {
|
|||
continue;
|
||||
}
|
||||
auto key = worker.GetParamKey(item.first);
|
||||
::ps::SArray<int> lookup_ids(swap_indices_lens, 0);
|
||||
::ps::SArray<float> swap_out_data;
|
||||
std::vector<int> lookup_ids(swap_indices_lens, 0);
|
||||
std::vector<float> swap_out_data;
|
||||
auto embedding_size = hash_info.embedding_size;
|
||||
swap_out_data.resize(swap_indices_lens * embedding_size);
|
||||
std::unique_ptr<float[]> device_hash_table_addr_tmp =
|
||||
|
|
|
@ -29,9 +29,9 @@
|
|||
#include "backend/kernel_compiler/kernel.h"
|
||||
#include "utils/shape_utils.h"
|
||||
#include "ir/tensor.h"
|
||||
#include "ps/ps.h"
|
||||
#include "ps/common.h"
|
||||
#include "ps/constants.h"
|
||||
#include "ps/worker.h"
|
||||
#include "ps/ps_context.h"
|
||||
#include "ps/ps_cache/ps_data/ps_data_prefetch.h"
|
||||
#include "ps/ps_cache/embedding_hash_map.h"
|
||||
#include "ps/ps_cache/ps_cache_factory.h"
|
||||
|
@ -155,7 +155,7 @@ class PsCacheManager {
|
|||
bool ParseDeviceData(size_t id, bool *need_swap_device_to_host, bool *need_swap_host_to_device, int *hash_index);
|
||||
bool ParseHostDataHostToDevice(size_t id);
|
||||
bool ParseHostDataDeviceToHost();
|
||||
bool HashSwapDeviceOut(int *swap_out_index, ::ps::SArray<float> *swap_out_data, const HashTableInfo &hash_info);
|
||||
bool HashSwapDeviceOut(int *swap_out_index, std::vector<float> *swap_out_data, const HashTableInfo &hash_info);
|
||||
bool HashSwapDeviceIn(int *swap_in_ids, int *swap_in_index, const HashTableInfo &hash_info, size_t key);
|
||||
bool HashSwapHostToDevice(const HashTableInfo &hash_info);
|
||||
bool HashSwapDeviceToHost(const HashTableInfo &hash_info);
|
||||
|
@ -165,7 +165,7 @@ class PsCacheManager {
|
|||
float *hash_table_addr);
|
||||
bool LookUpHostHashTable(size_t embedding_size, size_t indices_lens, const float *hash_table_addr,
|
||||
const int *indices_addr, float *output_addr);
|
||||
bool UpdataEmbeddingTable(const ::ps::SArray<float> &swap_out_data, int *swap_out_ids, size_t key);
|
||||
bool UpdataEmbeddingTable(const std::vector<float> &swap_out_data, int *swap_out_ids, size_t key);
|
||||
void LookUpTableTask(size_t indices_lens, size_t outer_dim_size, size_t first_dim_size, const float *input_addr,
|
||||
const int *indices_addr, float *output_addr);
|
||||
bool CheckFinishInsertInitInfo() const;
|
||||
|
|
|
@ -48,10 +48,10 @@ void PSContext::SetPSEnable(bool enabled) {
|
|||
MS_LOG(WARNING) << "MS_ROLE is " << ms_role << ", which is invalid.";
|
||||
}
|
||||
|
||||
worker_num_ = std::strtol(common::GetEnv("MS_WORKER_NUM").c_str(), nullptr, 10);
|
||||
server_num_ = std::strtol(common::GetEnv("MS_SERVER_NUM").c_str(), nullptr, 10);
|
||||
scheduler_host_ = common::GetEnv("MS_SCHED_HOST");
|
||||
scheduler_port_ = std::strtol(common::GetEnv("MS_SCHED_PORT").c_str(), nullptr, 10);
|
||||
worker_num_ = std::strtol(common::GetEnv(kEnvWorkerNum).c_str(), nullptr, 10);
|
||||
server_num_ = std::strtol(common::GetEnv(kEnvPServerNum).c_str(), nullptr, 10);
|
||||
scheduler_host_ = common::GetEnv(kEnvSchedulerHost);
|
||||
scheduler_port_ = std::strtol(common::GetEnv(kEnvSchedulerPort).c_str(), nullptr, 10);
|
||||
} else {
|
||||
MS_LOG(INFO) << "PS mode is disabled.";
|
||||
is_worker_ = false;
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "ps/constants.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
|
|
|
@ -15,13 +15,16 @@
|
|||
*/
|
||||
|
||||
#include "ps/scheduler.h"
|
||||
#include "ps/ps.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
void Scheduler::Run() {
|
||||
::ps::Start(0);
|
||||
::ps::Finalize(0, true);
|
||||
core::ClusterMetadata::instance()->Init(
|
||||
PSContext::instance()->initial_worker_num(), PSContext::instance()->initial_server_num(),
|
||||
PSContext::instance()->scheduler_host(), PSContext::instance()->scheduler_port());
|
||||
scheduler_node_.Start();
|
||||
scheduler_node_.Finish();
|
||||
scheduler_node_.Stop();
|
||||
exit(1);
|
||||
}
|
||||
} // namespace ps
|
||||
|
|
|
@ -16,6 +16,11 @@
|
|||
|
||||
#ifndef MINDSPORE_CCSRC_PS_SCHEDULER_H_
|
||||
#define MINDSPORE_CCSRC_PS_SCHEDULER_H_
|
||||
|
||||
#include "ps/core/scheduler_node.h"
|
||||
#include "ps/util.h"
|
||||
#include "ps/ps_context.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
class Scheduler {
|
||||
|
@ -32,6 +37,7 @@ class Scheduler {
|
|||
~Scheduler() = default;
|
||||
Scheduler(const Scheduler &) = delete;
|
||||
Scheduler &operator=(const Scheduler &) = delete;
|
||||
core::SchedulerNode scheduler_node_;
|
||||
};
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
#include "ps/util.h"
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include "ps/common.h"
|
||||
#include "ps/constants.h"
|
||||
#include "ps/ps_context.h"
|
||||
#include "utils/ms_utils.h"
|
||||
|
||||
|
@ -46,50 +46,10 @@ std::unordered_map<int64_t, std::string> Util::id_to_optimizer_nodes{
|
|||
{3, kSparseFtrlOp},
|
||||
};
|
||||
|
||||
bool Util::IsParamServerMode() { return PSContext::instance()->is_ps_mode(); }
|
||||
|
||||
bool Util::IsRoleOfWorker() { return PSContext::instance()->is_worker(); }
|
||||
|
||||
bool Util::IsRoleOfPServer() { return PSContext::instance()->is_server(); }
|
||||
|
||||
bool Util::IsRoleOfScheduler() { return PSContext::instance()->is_scheduler(); }
|
||||
|
||||
void Util::SetInternalEnvVar() {
|
||||
if (IsParamServerMode()) {
|
||||
auto comm_type = common::GetEnv(kEnvCommType);
|
||||
if (!comm_type.empty()) {
|
||||
(void)common::SetEnv(kDmlcCommType, comm_type.c_str());
|
||||
}
|
||||
auto interface = common::GetEnv(kEnvInterface);
|
||||
if (!interface.empty()) {
|
||||
(void)common::SetEnv(kDmlcInterface, interface.c_str());
|
||||
}
|
||||
auto server_num = common::GetEnv(kEnvPServerNum);
|
||||
if (!server_num.empty()) {
|
||||
(void)common::SetEnv(kDmlcPServerNum, server_num.c_str());
|
||||
}
|
||||
auto worker_num = common::GetEnv(kEnvWorkerNum);
|
||||
if (!worker_num.empty()) {
|
||||
(void)common::SetEnv(kDmlcWorkerNum, worker_num.c_str());
|
||||
}
|
||||
if (IsRoleOfScheduler()) {
|
||||
(void)common::SetEnv(kDmlcRole, kRoleOfScheduler);
|
||||
} else if (IsRoleOfPServer()) {
|
||||
(void)common::SetEnv(kDmlcRole, kRoleOfPServer);
|
||||
} else if (IsRoleOfWorker()) {
|
||||
(void)common::SetEnv(kDmlcRole, kRoleOfWorker);
|
||||
}
|
||||
auto scheduler_host = common::GetEnv(kEnvSchedulerHost);
|
||||
if (!scheduler_host.empty()) {
|
||||
(void)common::SetEnv(kDmlcSchedulerHost, scheduler_host.c_str());
|
||||
}
|
||||
auto scheduler_port = common::GetEnv(kEnvSchedulerPort);
|
||||
if (!scheduler_port.empty()) {
|
||||
(void)common::SetEnv(kDmlcSchedulerPort, scheduler_port.c_str());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int64_t Util::optimizer_id(std::string name) {
|
||||
if (optimizer_to_ids.count(name) > 0) {
|
||||
return optimizer_to_ids[name];
|
||||
|
|
|
@ -37,11 +37,8 @@ struct ParamInitInfo {
|
|||
|
||||
class Util {
|
||||
public:
|
||||
static bool IsParamServerMode();
|
||||
static bool IsRoleOfWorker();
|
||||
static bool IsRoleOfPServer();
|
||||
static bool IsRoleOfScheduler();
|
||||
static void SetInternalEnvVar();
|
||||
static int64_t optimizer_id(std::string name);
|
||||
static std::string optimizer_name(int64_t id);
|
||||
static std::string optimizer_node_name(int64_t id);
|
||||
|
|
|
@ -14,11 +14,10 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ps/internal/worker.h"
|
||||
#include "ps/worker.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace internal {
|
||||
void Worker::Run() {
|
||||
std::lock_guard<std::mutex> lock(running_mutex_);
|
||||
core::ClusterMetadata::instance()->Init(
|
||||
|
@ -198,7 +197,8 @@ void Worker::AddEmbeddingTable(const Key &key, const size_t &row_count) {
|
|||
}
|
||||
|
||||
void Worker::InitPSEmbeddingTable(const size_t &key, const std::vector<size_t> &input_shape,
|
||||
const std::vector<size_t> &indices_shape, const std::vector<size_t> &output_shape) {
|
||||
const std::vector<size_t> &indices_shape, const std::vector<size_t> &output_shape,
|
||||
const ParamInitInfoMessage &info) {
|
||||
bool has_init = IsKeyInit(key);
|
||||
if (has_init) {
|
||||
MS_LOG(DEBUG) << "The key embedding table of key " << key << " is initialized.";
|
||||
|
@ -210,6 +210,7 @@ void Worker::InitPSEmbeddingTable(const size_t &key, const std::vector<size_t> &
|
|||
*embedding_table_meta.mutable_input_shape() = {input_shape.begin(), input_shape.end()};
|
||||
*embedding_table_meta.mutable_indices_shape() = {indices_shape.begin(), indices_shape.end()};
|
||||
*embedding_table_meta.mutable_output_shape() = {output_shape.begin(), output_shape.end()};
|
||||
*embedding_table_meta.mutable_info() = info;
|
||||
|
||||
std::string kv_data = embedding_table_meta.SerializeAsString();
|
||||
|
||||
|
@ -295,19 +296,18 @@ void Worker::DoPSEmbeddingLookup(const Key &key, const std::vector<int> &lookup_
|
|||
int64_t single_id_len = SizeToLong(lookup_result->size() / lookup_ids.size());
|
||||
std::unordered_map<Key, std::shared_ptr<std::pair<float *, int64_t>>> id_addr_map;
|
||||
std::shared_ptr<std::vector<float>> values = std::make_shared<std::vector<float>>();
|
||||
int64_t value_offset = 0;
|
||||
for (size_t i = 0; i < resp.size(); ++i) {
|
||||
KVMessage message;
|
||||
message.ParseFromArray(resp.at(i)->data(), resp.at(i)->size());
|
||||
int64_t offset = 0;
|
||||
values->clear();
|
||||
for (auto j = 0; j < message.values_size(); j++) {
|
||||
values->push_back(message.values(j));
|
||||
}
|
||||
MS_LOG(DEBUG) << "the embedding resp:" << values;
|
||||
MS_LOG(DEBUG) << "The embedding resp:" << values;
|
||||
for (auto k = 0; k < message.keys_size(); k++) {
|
||||
const Key &key = message.keys(k);
|
||||
float *addr = values->data() + offset;
|
||||
offset += single_id_len;
|
||||
float *addr = values->data() + value_offset;
|
||||
value_offset += single_id_len;
|
||||
id_addr_map[key] = std::make_shared<std::pair<float *, int64_t>>(std::make_pair(addr, single_id_len));
|
||||
}
|
||||
}
|
||||
|
@ -969,6 +969,5 @@ void Worker::SendForPull(int cmd, const KVMessage &send, const KVPartitioner &pa
|
|||
}
|
||||
}
|
||||
}
|
||||
} // namespace internal
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -25,24 +25,38 @@
|
|||
#include <functional>
|
||||
#include <algorithm>
|
||||
#include <map>
|
||||
#include "ps/ps.h"
|
||||
#include <mutex>
|
||||
#include <unordered_set>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "utils/log_adapter.h"
|
||||
#include "ir/tensor.h"
|
||||
#include "ps/util.h"
|
||||
#include "ps/common.h"
|
||||
#include "ps/worker_proxy.h"
|
||||
#include "ps/constants.h"
|
||||
#include "utils/shape_utils.h"
|
||||
#include "ps/ps_cache/ps_data/ps_data_prefetch.h"
|
||||
#include "ps/core/worker_node.h"
|
||||
#include "ps/embedding_table_shard_metadata.h"
|
||||
#include "proto/comm.pb.h"
|
||||
#include "proto/ps.pb.h"
|
||||
#include "ps/ps_context.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
template <typename T>
|
||||
class Worker {
|
||||
public:
|
||||
static Worker &GetInstance() {
|
||||
static Worker instance;
|
||||
return instance;
|
||||
}
|
||||
using Callback = std::function<void()>;
|
||||
using PartitionEmbeddingMessages = std::vector<std::pair<bool, EmbeddingTableLookup>>;
|
||||
using PartitionKVMessages = std::vector<std::pair<bool, KVMessage>>;
|
||||
|
||||
using EmbeddingPartitioner = std::function<void(
|
||||
const EmbeddingTableLookup &send, PartitionEmbeddingMessages *partition, const std::map<int64_t, int64_t> &attrs)>;
|
||||
using KVPartitioner =
|
||||
std::function<void(const KVMessage &send, PartitionKVMessages *partition, const std::map<int64_t, int64_t> &attrs)>;
|
||||
|
||||
void Run();
|
||||
void Push(const std::vector<size_t> &keys, std::vector<uintptr_t> addrs, const ShapeVector &sizes);
|
||||
|
@ -53,340 +67,89 @@ class Worker {
|
|||
bool GetParamInitInServer(const std::string ¶m_name);
|
||||
void SetKeyOptimId(size_t key, const std::string &optimizer_name);
|
||||
void SetOptimInputShapes(size_t key, const ShapeVector &shape);
|
||||
void AddEmbeddingTable(const ::ps::Key &key, const size_t &row_count);
|
||||
void InitPSEmbeddingTable(const std::vector<size_t> &keys, std::vector<T> shapes, const ShapeVector &sizes);
|
||||
void AddEmbeddingTable(const Key &key, const size_t &row_count);
|
||||
void InitPSEmbeddingTable(const size_t &key, const std::vector<size_t> &input_shape,
|
||||
const std::vector<size_t> &indices_shape, const std::vector<size_t> &output_shape,
|
||||
const ParamInitInfoMessage &info);
|
||||
void InitPSParamAndOptim(const AnfNodePtr &input_node, const tensor::TensorPtr &tensor);
|
||||
void DoPSEmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids,
|
||||
const ::ps::SArray<int> &lens, ::ps::SArray<T> *lookup_result, int64_t cmd);
|
||||
void UpdateEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids,
|
||||
const ::ps::SArray<T> &vals);
|
||||
void DoPSEmbeddingLookup(const Key &key, const std::vector<int> &lookup_ids, std::vector<float> *lookup_result,
|
||||
int64_t cmd);
|
||||
void UpdateEmbeddingTable(const std::vector<Key> &keys, const std::vector<int> &lookup_ids,
|
||||
const std::vector<float> &vals);
|
||||
|
||||
bool running() { return running_; }
|
||||
void Finalize();
|
||||
|
||||
private:
|
||||
Worker() : kv_worker_(nullptr), running_(false), key_cnt_(0) {}
|
||||
Worker() : running_(false), key_cnt_(0) {}
|
||||
~Worker() = default;
|
||||
Worker(const Worker &) = delete;
|
||||
Worker &operator=(const Worker &) = delete;
|
||||
|
||||
void Initialize();
|
||||
bool IsKeyInit(const size_t key);
|
||||
void AddKeyToServerId(const Key &key);
|
||||
void AddKeyByHashMod(const Key &key);
|
||||
void InitPSOptimId(const size_t param_key);
|
||||
void InitPSOptimInputShapes(const size_t key);
|
||||
void InitPSParamData(const std::vector<size_t> &keys, void *origin_addr, size_t size);
|
||||
static void EmbeddingLookupIdSlicer(const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &ranges,
|
||||
std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced) {}
|
||||
bool IsReadyForPush(const Key &key);
|
||||
bool IsReadyForPull(const Key &key);
|
||||
void PrepareSparseGradient(const size_t begin, const size_t end, const std::unordered_set<int> &distinct_ids,
|
||||
const std::vector<std::pair<int, float *>> &indice_to_grads, const int *all_indice,
|
||||
const size_t segment_size, float *gradient, int *indices);
|
||||
void BuildSparseValue(const std::vector<int> &lengths, const size_t grad_index, const size_t indice_index,
|
||||
const float *original_data, const float *grads, int *indices, std::vector<float> *reduced_data);
|
||||
|
||||
std::shared_ptr<WorkerProxy<T>> kv_worker_;
|
||||
void PushData(const std::vector<Key> &keys, const std::vector<float> &vals, const std::vector<int> &lens = {},
|
||||
int command = 0, int64_t priority = 0);
|
||||
void PushSparseData(const std::vector<Key> &keys, const std::vector<float> &vals, const std::vector<int> &lens,
|
||||
size_t grad_index, size_t indice_index, size_t first_dim_size, size_t outer_dim_size);
|
||||
void PullData(const std::vector<Key> &keys, std::vector<float> *vals, std::vector<int> *lens = nullptr, int cmd = 0,
|
||||
int64_t priority = 0);
|
||||
|
||||
void LookupIdPartitioner(const EmbeddingTableLookup &send, PartitionEmbeddingMessages *partition,
|
||||
const std::map<int64_t, int64_t> &attrs);
|
||||
|
||||
void SparsePartitioner(const KVMessage &send, PartitionKVMessages *partition,
|
||||
const std::map<int64_t, int64_t> &attrs);
|
||||
void RoundRobinPartitioner(const KVMessage &send, PartitionKVMessages *partition,
|
||||
const std::map<int64_t, int64_t> &attrs);
|
||||
void WorkerInitEmbeddingPartitioner(const KVMessage &send, std::vector<std::pair<bool, KVMessage>> *partition,
|
||||
const std::map<int64_t, int64_t> &attrs);
|
||||
void UpdateEmbeddingPartitioner(const KVMessage &send, PartitionKVMessages *partition,
|
||||
const std::map<int64_t, int64_t> &attrs);
|
||||
void BroadcastPartitioner(const KVMessage &send, PartitionKVMessages *partition,
|
||||
const std::map<int64_t, int64_t> &attrs);
|
||||
void SendForPush(int cmd, const KVMessage &send, const KVPartitioner &partitioner,
|
||||
const std::map<int64_t, int64_t> &attrs);
|
||||
void SendForPull(int cmd, const KVMessage &send, const KVPartitioner &partitioner,
|
||||
const std::map<int64_t, int64_t> &attrs, std::vector<float> *vals, std::vector<int> *lens);
|
||||
|
||||
int64_t server_num_;
|
||||
bool running_;
|
||||
std::mutex running_mutex_;
|
||||
size_t key_cnt_;
|
||||
std::map<std::string, size_t> param_to_key_;
|
||||
std::map<size_t, bool> init_keys_;
|
||||
std::map<size_t, int64_t> key_to_optimId_;
|
||||
std::map<size_t, std::vector<ShapeVector>> key_to_optim_shapes_;
|
||||
std::map<std::string, bool> param_to_init_in_server_;
|
||||
core::WorkerNode worker_node_;
|
||||
|
||||
EmbeddingPartitioner lookup_partitioner_;
|
||||
KVPartitioner sparse_partitioner_;
|
||||
KVPartitioner round_robin_partitioner_;
|
||||
KVPartitioner worker_init_embedding_partitioner_;
|
||||
KVPartitioner update_embedding_partitioner_;
|
||||
KVPartitioner broadcast_partitioner_;
|
||||
std::unordered_map<Key, int64_t> key_to_server_id_;
|
||||
std::unordered_map<Key, size_t> embedding_row_cnt_;
|
||||
|
||||
std::unordered_map<Key, std::shared_ptr<std::vector<EmbeddingTableShardMetadata>>> embedding_table_ranges_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
void Worker<T>::Run() {
|
||||
if (running_) {
|
||||
MS_LOG(INFO) << "'Worker is already running.";
|
||||
return;
|
||||
}
|
||||
MS_LOG(INFO) << "Worker starts connecting to scheduler and server...";
|
||||
::ps::Start(0);
|
||||
MS_LOG(INFO) << "Worker connected successfully.";
|
||||
if (!::ps::IsWorker()) {
|
||||
MS_LOG(EXCEPTION) << "The role is not worker.";
|
||||
}
|
||||
kv_worker_ = std::make_shared<WorkerProxy<T>>(0, 0, 1, 2);
|
||||
running_ = true;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void Worker<T>::Push(const std::vector<size_t> &keys, std::vector<uintptr_t> addrs, const ShapeVector &sizes) {
|
||||
if (keys.size() == 0) {
|
||||
MS_LOG(EXCEPTION) << "key size should be greater than zero";
|
||||
}
|
||||
if (key_to_optimId_.count(keys[0]) == 0) {
|
||||
MS_LOG(EXCEPTION) << "no optim id found for key" << keys[0];
|
||||
}
|
||||
Key key = keys[0];
|
||||
int64_t optim_id = key_to_optimId_[key];
|
||||
bool is_sparse = false;
|
||||
if (optim_id == 1 || optim_id == 2 || optim_id == 3) {
|
||||
is_sparse = true;
|
||||
}
|
||||
int64_t grad_index = -1;
|
||||
int64_t indice_index = -1;
|
||||
|
||||
// Sparse adam gradient
|
||||
if (optim_id == 1 || optim_id == 2) {
|
||||
grad_index = 6;
|
||||
indice_index = 7;
|
||||
|
||||
// Sparse ftrl gradient
|
||||
} else if (optim_id == 3) {
|
||||
grad_index = 0;
|
||||
indice_index = 1;
|
||||
}
|
||||
|
||||
size_t total_size = std::accumulate(sizes.begin(), sizes.end(), 0, std::plus<int64_t>());
|
||||
::ps::SArray<T> total_buffer(total_size, 0);
|
||||
size_t offset = 0;
|
||||
size_t dst_size = 0;
|
||||
size_t src_size = 0;
|
||||
for (size_t i = 0; i < sizes.size(); i++) {
|
||||
void *dst_data = total_buffer.data() + offset / sizeof(T);
|
||||
void *src_data = reinterpret_cast<void *>(addrs[i]);
|
||||
MS_EXCEPTION_IF_NULL(dst_data);
|
||||
MS_EXCEPTION_IF_NULL(src_data);
|
||||
dst_size = sizes[i] * sizeof(T);
|
||||
src_size = sizes[i] * sizeof(T);
|
||||
auto ret = memcpy_s(dst_data, dst_size, src_data, src_size);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
|
||||
return;
|
||||
}
|
||||
offset += sizes[i] * sizeof(T);
|
||||
}
|
||||
|
||||
while (!kv_worker_->IsReadyForPush(keys[0])) {
|
||||
continue;
|
||||
}
|
||||
std::vector<int> sizes_int;
|
||||
(void)std::transform(sizes.begin(), sizes.end(), std::back_inserter(sizes_int),
|
||||
[](const int64_t &value) { return static_cast<int>(value); });
|
||||
if (!is_sparse) {
|
||||
kv_worker_->PushData(::ps::SArray<::ps::Key>(keys), total_buffer, ::ps::SArray<int>(sizes_int));
|
||||
} else {
|
||||
std::vector<int64_t> &var_shape = key_to_optim_shapes_[key][0];
|
||||
int64_t first_dim_size = var_shape[0];
|
||||
int64_t outer_dim_size = std::accumulate(var_shape.begin() + 1, var_shape.end(), 1, std::multiplies<int64_t>());
|
||||
kv_worker_->PushSparseData(::ps::SArray<::ps::Key>(keys), total_buffer, ::ps::SArray<int>(sizes_int), grad_index,
|
||||
indice_index, first_dim_size, outer_dim_size);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void Worker<T>::Pull(const size_t key, void *dev_addr, const size_t size) {
|
||||
MS_EXCEPTION_IF_NULL(dev_addr);
|
||||
::ps::SArray<T> variables(size / sizeof(T), 0);
|
||||
while (!kv_worker_->IsReadyForPull(key)) {
|
||||
continue;
|
||||
}
|
||||
kv_worker_->PullData({key}, &variables);
|
||||
size_t dst_size = size;
|
||||
size_t src_size = size;
|
||||
auto ret = memcpy_s(dev_addr, dst_size, variables.data(), src_size);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void Worker<T>::DoPSEmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids,
|
||||
const ::ps::SArray<int> &lens, ::ps::SArray<T> *lookup_result, int64_t cmd) {
|
||||
MS_EXCEPTION_IF_NULL(lookup_result);
|
||||
kv_worker_->EmbeddingLookup(keys, lookup_ids, lens, lookup_result, cmd);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void Worker<T>::UpdateEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids,
|
||||
const ::ps::SArray<T> &vals) {
|
||||
kv_worker_->UpdateEmbeddingTable(keys, lookup_ids, vals);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void Worker<T>::Finalize() {
|
||||
if (running_) {
|
||||
MS_LOG(INFO) << "Worker starts finalizing...";
|
||||
kv_worker_->Finalize();
|
||||
kv_worker_.reset();
|
||||
running_ = false;
|
||||
MS_LOG(INFO) << "Worker finalized successfully.";
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void Worker<T>::InitPSParamData(const std::vector<size_t> &keys, void *origin_addr, size_t size) {
|
||||
MS_EXCEPTION_IF_NULL(origin_addr);
|
||||
::ps::SArray<T> addr(reinterpret_cast<T *>(origin_addr), size / sizeof(T));
|
||||
::ps::SArray<::ps::Key> key(keys);
|
||||
::ps::SArray<int> lens;
|
||||
lens.push_back(addr.size());
|
||||
kv_worker_->PushData(key, addr, lens, kInitWeightsCmd);
|
||||
init_keys_[key[0]] = true;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void Worker<T>::SetOptimInputShapes(size_t key, const ShapeVector &shape) {
|
||||
if (key_to_optim_shapes_.find(key) == key_to_optim_shapes_.end()) {
|
||||
key_to_optim_shapes_[key] = {shape};
|
||||
} else {
|
||||
key_to_optim_shapes_[key].push_back(shape);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void Worker<T>::InitPSOptimInputShapes(const size_t key) {
|
||||
::ps::SArray<::ps::Key> keys;
|
||||
::ps::SArray<int> shape_len;
|
||||
::ps::SArray<T> all_shape;
|
||||
std::vector<ShapeVector> shapes = key_to_optim_shapes_[key];
|
||||
for (auto shape : shapes) {
|
||||
keys.push_back(key);
|
||||
if (shape.size() == 0) {
|
||||
shape_len.push_back(1);
|
||||
all_shape.push_back(1);
|
||||
} else {
|
||||
shape_len.push_back(SizeToLong(shape.size()));
|
||||
for (auto dim : shape) {
|
||||
all_shape.push_back(static_cast<T>(dim));
|
||||
}
|
||||
}
|
||||
}
|
||||
MS_LOG(INFO) << "keys:" << keys;
|
||||
MS_LOG(INFO) << "shape_len:" << shape_len;
|
||||
MS_LOG(INFO) << "all_shape:" << all_shape;
|
||||
if (!init_keys_[key]) {
|
||||
init_keys_[key] = true;
|
||||
}
|
||||
kv_worker_->PushData(keys, all_shape, shape_len, kInitOptimInputsShapeCmd);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool Worker<T>::IsKeyInit(const size_t key) {
|
||||
if (init_keys_.find(key) == init_keys_.end() || !init_keys_[key]) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
size_t Worker<T>::SetParamKey(const std::string ¶m_name) {
|
||||
size_t key = UINT64_MAX;
|
||||
if (param_to_key_.count(param_name)) {
|
||||
key = param_to_key_[param_name];
|
||||
MS_LOG(INFO) << param_name << " key is already set: key value is " << key;
|
||||
} else {
|
||||
key = key_cnt_++;
|
||||
param_to_key_[param_name] = key;
|
||||
MS_LOG(INFO) << "Set key " << key << " for parameter " << param_name;
|
||||
}
|
||||
return key;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void Worker<T>::SetParamInitInServer(const std::string ¶m_name, bool init_in_server) {
|
||||
MS_LOG(INFO) << "Set parameter " << param_name << " init_in_server:" << init_in_server;
|
||||
param_to_init_in_server_[param_name] = init_in_server;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool Worker<T>::GetParamInitInServer(const std::string ¶m_name) {
|
||||
if (param_to_init_in_server_.count(param_name) == 0) {
|
||||
return false;
|
||||
}
|
||||
return param_to_init_in_server_[param_name];
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
size_t Worker<T>::GetParamKey(const std::string ¶m_name) {
|
||||
size_t key = kInvalidKey;
|
||||
if (param_to_key_.find(param_name) != param_to_key_.end()) {
|
||||
key = param_to_key_[param_name];
|
||||
MS_LOG(DEBUG) << "Get key of parameter " << param_name << " key is " << key;
|
||||
}
|
||||
return key;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void Worker<T>::SetKeyOptimId(size_t key, const std::string &optimizer_name) {
|
||||
key_to_optimId_[key] = Util::optimizer_id(optimizer_name);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void Worker<T>::InitPSOptimId(const size_t param_key) {
|
||||
if (key_to_optimId_.count(param_key) == 0) {
|
||||
MS_LOG(EXCEPTION) << "Can't find optimizer id of parameter key " << param_key;
|
||||
}
|
||||
int64_t optim_id = key_to_optimId_[param_key];
|
||||
|
||||
::ps::SArray<::ps::Key> keys = {param_key};
|
||||
::ps::SArray<T> optim_id_vals = {static_cast<T>(optim_id)};
|
||||
::ps::SArray<int> optim_id_lens = {optim_id_vals.size()};
|
||||
kv_worker_->PushData(keys, optim_id_vals, optim_id_lens, kInitWeightToOptimIdCmd);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void Worker<T>::InitPSEmbeddingTable(const std::vector<size_t> &keys, std::vector<T> shapes, const ShapeVector &sizes) {
|
||||
bool has_init = IsKeyInit(keys[0]);
|
||||
if (has_init) {
|
||||
MS_LOG(DEBUG) << "The key embedding table of key " << keys[0] << " is initialized.";
|
||||
return;
|
||||
}
|
||||
::ps::SArray<T> shapes_val;
|
||||
for (auto dim : shapes) {
|
||||
shapes_val.push_back(dim);
|
||||
}
|
||||
std::vector<int> sizes_int;
|
||||
(void)std::transform(sizes.begin(), sizes.end(), std::back_inserter(sizes_int),
|
||||
[](const int64_t &value) { return static_cast<int>(value); });
|
||||
kv_worker_->Wait(
|
||||
kv_worker_->InitEmbeddingTable(::ps::SArray<::ps::Key>(keys), shapes_val, ::ps::SArray<int>(sizes_int)));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void Worker<T>::InitPSParamAndOptim(const AnfNodePtr &input_node, const tensor::TensorPtr &tensor) {
|
||||
MS_EXCEPTION_IF_NULL(tensor);
|
||||
MS_EXCEPTION_IF_NULL(input_node);
|
||||
auto pk_node = input_node->cast<ParameterPtr>();
|
||||
MS_EXCEPTION_IF_NULL(pk_node);
|
||||
const std::string ¶m_name = pk_node->fullname_with_scope();
|
||||
void *param_data = tensor->data_c();
|
||||
size_t param_size = LongToSize(tensor->data().nbytes());
|
||||
|
||||
size_t param_key = GetParamKey(param_name);
|
||||
if (param_key == kInvalidKey) {
|
||||
MS_LOG(DEBUG) << "Parameter " << param_name << " has no key assigned.";
|
||||
return;
|
||||
}
|
||||
bool init_in_server = false;
|
||||
auto param_info_ptr = pk_node->param_info();
|
||||
if (param_info_ptr != nullptr && param_info_ptr->init_in_server()) {
|
||||
init_in_server = true;
|
||||
}
|
||||
SetParamInitInServer(param_name, init_in_server);
|
||||
bool init = IsKeyInit(param_key);
|
||||
if (!init) {
|
||||
MS_LOG(INFO) << "Init paramter and optimizer in parameter server side for " << param_name
|
||||
<< ", whether init in server: " << init_in_server;
|
||||
kv_worker_->AddKeyToServerId(param_key);
|
||||
if (!PsDataPrefetch::GetInstance().cache_enable()) {
|
||||
if (!init_in_server) {
|
||||
if (param_size > INT_MAX) {
|
||||
MS_LOG(EXCEPTION) << "PS mode max weight size is " << INT_MAX << ", " << param_name << " size is "
|
||||
<< param_size;
|
||||
}
|
||||
InitPSParamData({param_key}, param_data, param_size);
|
||||
}
|
||||
InitPSOptimId(param_key);
|
||||
InitPSOptimInputShapes(param_key);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void Worker<T>::AddEmbeddingTable(const ::ps::Key &key, const size_t &row_count) {
|
||||
bool has_init = IsKeyInit(key);
|
||||
if (has_init) {
|
||||
return;
|
||||
}
|
||||
kv_worker_->AddEmbeddingTable(key, row_count);
|
||||
}
|
||||
|
||||
static Worker<float> &worker = Worker<float>::GetInstance();
|
||||
static Worker &worker = Worker::GetInstance();
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PS_WORKER_H_
|
||||
|
|
|
@ -1,873 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PS_WORKER_PROXY_H_
|
||||
#define MINDSPORE_CCSRC_PS_WORKER_PROXY_H_
|
||||
|
||||
#include <map>
|
||||
#include <numeric>
|
||||
#include <functional>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include "ps/ps.h"
|
||||
#include "ps/util.h"
|
||||
#include "backend/kernel_compiler/common_utils.h"
|
||||
#include "ps/ps_context.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
template <typename T>
|
||||
class WorkerProxy : public ::ps::KVWorker<T> {
|
||||
public:
|
||||
using Worker = ::ps::KVWorker<T>;
|
||||
using Callback = std::function<void()>;
|
||||
using SlicedKVs = std::vector<std::pair<bool, ::ps::KVPairs<T>>>;
|
||||
using Slicer = std::function<void(int64_t ts, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &ranges,
|
||||
SlicedKVs *sliced, const std::map<int64_t, int64_t> &attrs)>;
|
||||
using ::ps::SimpleApp::obj_;
|
||||
explicit WorkerProxy(int64_t app_id, int64_t customer_id, int64_t lookup_customer_id, int64_t general_customer_id)
|
||||
: Worker(app_id, customer_id) {
|
||||
server_num_ = ::ps::NumServers();
|
||||
MS_LOG(INFO) << "Server num:" << server_num_;
|
||||
PSContext::instance()->SetPSRankId(::ps::MyRank());
|
||||
using std::placeholders::_1;
|
||||
using std::placeholders::_2;
|
||||
using std::placeholders::_3;
|
||||
using std::placeholders::_4;
|
||||
using std::placeholders::_5;
|
||||
lookup_customer_ = std::unique_ptr<::ps::Customer>(
|
||||
new ::ps::Customer(app_id, lookup_customer_id, std::bind(&WorkerProxy<T>::ProcessLookupResult, this, _1)));
|
||||
general_customer_ = std::unique_ptr<::ps::Customer>(
|
||||
new ::ps::Customer(app_id, general_customer_id, std::bind(&WorkerProxy<T>::ProcessResponse, this, _1)));
|
||||
lookup_slicer_ = std::bind(&WorkerProxy<T>::LookupIdSlicer, this, _1, _2, _3, _4, _5);
|
||||
sparse_slicer_ = std::bind(&WorkerProxy<T>::SparseSlicer, this, _1, _2, _3, _4, _5);
|
||||
broadcast_slicer_ = std::bind(&WorkerProxy<T>::BroadcastSlicer, this, _1, _2, _3, _4, _5);
|
||||
round_robin_slicer_ = std::bind(&WorkerProxy<T>::RoundRobinSlicer, this, _1, _2, _3, _4, _5);
|
||||
worker_init_embedding_slicer_ = std::bind(&WorkerProxy<T>::WorkerInitEmbeddingSlicer, this, _1, _2, _3, _4, _5);
|
||||
update_embedding_slicer_ = std::bind(&WorkerProxy<T>::UpdateEmbeddingSlicer, this, _1, _2, _3, _4, _5);
|
||||
}
|
||||
~WorkerProxy() override = default;
|
||||
|
||||
void AddEmbeddingTable(const ::ps::Key &key, const size_t &row_count);
|
||||
void AddKeyToServerId(const ::ps::Key &key);
|
||||
void EmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids,
|
||||
const ::ps::SArray<int> &lens, ::ps::SArray<T> *outs, int64_t cmd = 0,
|
||||
const Callback &cb = nullptr, int64_t priority = 0);
|
||||
int64_t InitEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &vals,
|
||||
const ::ps::SArray<int> &lens = {}, const Callback &cb = nullptr, int64_t priority = 0);
|
||||
void UpdateEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids,
|
||||
const ::ps::SArray<T> &vals, const Callback &cb = nullptr, int64_t priority = 0);
|
||||
bool IsReadyForPush(const Key &key);
|
||||
bool IsReadyForPull(const Key &key);
|
||||
void PushData(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &vals, const ::ps::SArray<int> &lens = {},
|
||||
int64_t cmd = 0, int64_t priority = 0);
|
||||
void PushSparseData(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &vals, const ::ps::SArray<int> &lens,
|
||||
size_t grad_index, size_t indice_index, size_t first_dim_size, size_t outer_dim_size);
|
||||
void PullData(const ::ps::SArray<::ps::Key> &keys, ::ps::SArray<T> *vals, ::ps::SArray<int> *lens = nullptr,
|
||||
int64_t cmd = 0, int64_t priority = 0);
|
||||
void Finalize();
|
||||
|
||||
private:
|
||||
template <typename C>
|
||||
int64_t AddLookupCB(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids, C *vals, int64_t cmd,
|
||||
const Callback &cb);
|
||||
int64_t AddGeneralRspCB(const ::ps::SArray<::ps::Key> &keys, ::ps::SArray<T> *vals, ::ps::SArray<int> *lens,
|
||||
int64_t cmd, const Callback &cb);
|
||||
void LookupIdSlicer(int64_t timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
|
||||
std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced, const std::map<int64_t, int64_t> &attrs);
|
||||
void SparseSlicer(int64_t timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
|
||||
std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced, const std::map<int64_t, int64_t> &attrs);
|
||||
void BroadcastSlicer(int64_t timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
|
||||
std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced, const std::map<int64_t, int64_t> &attrs);
|
||||
void RoundRobinSlicer(int64_t timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
|
||||
std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced,
|
||||
const std::map<int64_t, int64_t> &attrs);
|
||||
void WorkerInitEmbeddingSlicer(int64_t timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
|
||||
std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced,
|
||||
const std::map<int64_t, int64_t> &attrs);
|
||||
void UpdateEmbeddingSlicer(int timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
|
||||
std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced,
|
||||
const std::map<int64_t, int64_t> &attrs);
|
||||
void ProcessLookupResult(const ::ps::Message &msg);
|
||||
void ProcessResponse(const ::ps::Message &msg);
|
||||
void Send(::ps::Customer *customer, int64_t timestamp, bool push, bool pull, int64_t cmd, const ::ps::KVPairs<T> &kvs,
|
||||
const Slicer &slicer, std::map<int64_t, int64_t> attrs = {});
|
||||
void AddKeyByHashMod(const ::ps::Key &key);
|
||||
|
||||
void PrepareSparseGradient(const size_t begin, const size_t end, const std::unordered_set<int> &distinct_ids,
|
||||
const std::vector<std::pair<int, T *>> &indice_to_grad, const int *all_indice,
|
||||
const size_t segment_size, T *gradient, int *indice);
|
||||
void BuildSparseValue(const ::ps::SArray<int> &lengths, const size_t grad_index, const size_t indice_index,
|
||||
const T *original_data, const T *grads, int *indices, ::ps::SArray<T> *reduced_data);
|
||||
|
||||
int64_t server_num_;
|
||||
std::unique_ptr<::ps::Customer> lookup_customer_;
|
||||
std::unique_ptr<::ps::Customer> general_customer_;
|
||||
std::unordered_map<::ps::Key, std::shared_ptr<std::vector<::ps::Range>>> embedding_table_ranges_;
|
||||
std::unordered_map<int64_t, std::vector<::ps::KVPairs<T>>> lookup_results_;
|
||||
std::unordered_map<int64_t, std::map<int64_t, ::ps::KVPairs<T>>> gathered_response_;
|
||||
std::mutex mutex_;
|
||||
Slicer lookup_slicer_;
|
||||
Slicer sparse_slicer_;
|
||||
Slicer broadcast_slicer_;
|
||||
Slicer round_robin_slicer_;
|
||||
Slicer worker_init_embedding_slicer_;
|
||||
Slicer update_embedding_slicer_;
|
||||
std::unordered_map<int64_t, Callback> lookup_callbacks_;
|
||||
std::unordered_map<int64_t, Callback> general_callbacks_;
|
||||
std::unordered_map<int64_t, int64_t> expected_result_count_;
|
||||
std::unordered_map<::ps::Key, int64_t> key_to_server_id_;
|
||||
std::unordered_map<::ps::Key, size_t> embedding_row_cnt_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
void WorkerProxy<T>::AddEmbeddingTable(const ::ps::Key &key, const size_t &row_count) {
|
||||
uint64_t begin = 0;
|
||||
uint64_t end = 0;
|
||||
for (int64_t i = 0; i < server_num_; i++) {
|
||||
int64_t local_row_cnt = Util::LocalShard(row_count, i, server_num_);
|
||||
if (i == 0) {
|
||||
end = local_row_cnt - 1;
|
||||
} else {
|
||||
begin = end + 1;
|
||||
end += local_row_cnt;
|
||||
}
|
||||
::ps::Range range(begin, end);
|
||||
if (embedding_table_ranges_.count(key) == 0) {
|
||||
embedding_table_ranges_[key] = std::make_shared<std::vector<::ps::Range>>();
|
||||
MS_EXCEPTION_IF_NULL(embedding_table_ranges_[key]);
|
||||
}
|
||||
embedding_table_ranges_[key]->push_back(range);
|
||||
}
|
||||
embedding_row_cnt_[key] = row_count;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void WorkerProxy<T>::AddKeyByHashMod(const ::ps::Key &key) {
|
||||
if (server_num_ == 0) {
|
||||
MS_LOG(EXCEPTION) << "Server number is invalid:0";
|
||||
}
|
||||
key_to_server_id_[key] = static_cast<int64_t>(key % server_num_);
|
||||
MS_LOG(INFO) << "The server id of key " << key << " is " << key_to_server_id_[key];
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void WorkerProxy<T>::AddKeyToServerId(const ::ps::Key &key) {
|
||||
AddKeyByHashMod(key);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void WorkerProxy<T>::EmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids,
|
||||
const ::ps::SArray<int> &lens, ::ps::SArray<T> *outs, int64_t cmd,
|
||||
const Callback &cb, int64_t priority) {
|
||||
int64_t ts = AddLookupCB(keys, lookup_ids, outs, cmd, cb);
|
||||
::ps::KVPairs<T> kvs;
|
||||
kvs.keys = keys;
|
||||
kvs.lens = lookup_ids;
|
||||
kvs.priority = priority;
|
||||
expected_result_count_[ts] = 0;
|
||||
Send(lookup_customer_.get(), ts, true, true, cmd, kvs, lookup_slicer_);
|
||||
int64_t expect_rt_count = expected_result_count_[ts];
|
||||
lookup_customer_->AddResponse(ts, server_num_ - expect_rt_count);
|
||||
lookup_customer_->WaitRequest(ts);
|
||||
expected_result_count_.erase(ts);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
int64_t WorkerProxy<T>::InitEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &vals,
|
||||
const ::ps::SArray<int> &lens, const Callback &cb, int64_t priority) {
|
||||
int64_t ts = obj_->NewRequest(::ps::kServerGroup);
|
||||
::ps::KVPairs<T> kvs;
|
||||
kvs.keys = keys;
|
||||
kvs.vals = vals;
|
||||
kvs.lens = lens;
|
||||
kvs.priority = priority;
|
||||
Send(obj_, ts, true, false, kInitEmbeddingsCmd, kvs, broadcast_slicer_);
|
||||
return ts;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void WorkerProxy<T>::UpdateEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids,
|
||||
const ::ps::SArray<T> &vals, const Callback &cb, int64_t priority) {
|
||||
int ts = AddGeneralRspCB(keys, nullptr, nullptr, 0, nullptr);
|
||||
::ps::KVPairs<T> kvs;
|
||||
kvs.keys = keys;
|
||||
kvs.lens = lookup_ids;
|
||||
kvs.vals = vals;
|
||||
kvs.priority = priority;
|
||||
expected_result_count_[ts] = 0;
|
||||
Send(general_customer_.get(), ts, true, false, kUpdateEmbeddingsCmd, kvs, update_embedding_slicer_);
|
||||
if (expected_result_count_[ts] < server_num_) {
|
||||
general_customer_->AddResponse(ts, server_num_ - expected_result_count_[ts]);
|
||||
}
|
||||
general_customer_->WaitRequest(ts);
|
||||
expected_result_count_.erase(ts);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool WorkerProxy<T>::IsReadyForPush(const Key &key) {
|
||||
::ps::SArray<T> result(1, 0);
|
||||
PullData({key}, &result, nullptr, kCheckReadyForPushCmd);
|
||||
if (result[0] > 0) {
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool WorkerProxy<T>::IsReadyForPull(const Key &key) {
|
||||
::ps::SArray<T> result(1, 0);
|
||||
PullData({key}, &result, nullptr, kCheckReadyForPullCmd);
|
||||
if (result[0] > 0) {
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void WorkerProxy<T>::PushData(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &vals,
|
||||
const ::ps::SArray<int> &lens, int64_t cmd, int64_t priority) {
|
||||
int64_t ts = AddGeneralRspCB(keys, nullptr, nullptr, cmd, nullptr);
|
||||
::ps::KVPairs<T> kvs;
|
||||
kvs.keys = keys;
|
||||
kvs.vals = vals;
|
||||
kvs.lens = lens;
|
||||
kvs.priority = priority;
|
||||
if (embedding_table_ranges_.count(keys[0])) {
|
||||
if (cmd == kInitWeightsCmd) {
|
||||
Send(general_customer_.get(), ts, true, false, cmd, kvs, worker_init_embedding_slicer_);
|
||||
} else {
|
||||
Send(general_customer_.get(), ts, true, false, cmd, kvs, broadcast_slicer_);
|
||||
}
|
||||
} else {
|
||||
Send(general_customer_.get(), ts, true, false, cmd, kvs, round_robin_slicer_);
|
||||
}
|
||||
if (expected_result_count_[ts] < server_num_) {
|
||||
general_customer_->AddResponse(ts, server_num_ - expected_result_count_[ts]);
|
||||
}
|
||||
general_customer_->WaitRequest(ts);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void WorkerProxy<T>::PushSparseData(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &vals,
|
||||
const ::ps::SArray<int> &lens, size_t grad_index, size_t indice_index,
|
||||
size_t first_dim_size, size_t outer_dim_size) {
|
||||
int64_t ts = AddGeneralRspCB(keys, nullptr, nullptr, 0, nullptr);
|
||||
::ps::KVPairs<T> kvs;
|
||||
kvs.keys = keys;
|
||||
kvs.vals = vals;
|
||||
kvs.lens = lens;
|
||||
const int64_t cmd = 0;
|
||||
if (embedding_table_ranges_.count(keys[0])) {
|
||||
std::map<int64_t, int64_t> attrs{{0, grad_index}, {1, indice_index}, {2, first_dim_size}, {3, outer_dim_size}};
|
||||
Send(general_customer_.get(), ts, true, false, cmd, kvs, sparse_slicer_, attrs);
|
||||
} else {
|
||||
Send(general_customer_.get(), ts, true, false, cmd, kvs, round_robin_slicer_);
|
||||
}
|
||||
if (expected_result_count_[ts] < server_num_) {
|
||||
general_customer_->AddResponse(ts, server_num_ - expected_result_count_[ts]);
|
||||
}
|
||||
general_customer_->WaitRequest(ts);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void WorkerProxy<T>::PullData(const ::ps::SArray<::ps::Key> &keys, ::ps::SArray<T> *vals, ::ps::SArray<int> *lens,
|
||||
int64_t cmd, int64_t priority) {
|
||||
MS_EXCEPTION_IF_NULL(vals);
|
||||
int64_t ts = AddGeneralRspCB(keys, vals, lens, cmd, nullptr);
|
||||
::ps::KVPairs<T> kvs;
|
||||
kvs.keys = keys;
|
||||
kvs.priority = priority;
|
||||
if (embedding_table_ranges_.count(keys[0])) {
|
||||
Send(general_customer_.get(), ts, false, true, cmd, kvs, broadcast_slicer_);
|
||||
} else {
|
||||
Send(general_customer_.get(), ts, false, true, cmd, kvs, round_robin_slicer_);
|
||||
}
|
||||
if (expected_result_count_[ts] < server_num_) {
|
||||
general_customer_->AddResponse(ts, server_num_ - expected_result_count_[ts]);
|
||||
}
|
||||
general_customer_->WaitRequest(ts);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void WorkerProxy<T>::Finalize() {
|
||||
int64_t ts = obj_->NewRequest(::ps::kServerGroup);
|
||||
::ps::KVPairs<T> kvs;
|
||||
kvs.keys.push_back(0);
|
||||
kvs.vals.push_back(0.0f);
|
||||
Send(obj_, ts, true, false, kFinalizeCmd, kvs, broadcast_slicer_);
|
||||
obj_->WaitRequest(ts);
|
||||
::ps::Finalize(0, true);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
template <typename C>
|
||||
int64_t WorkerProxy<T>::AddLookupCB(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids,
|
||||
C *lookup_result, int64_t cmd, const Callback &cb) {
|
||||
MS_EXCEPTION_IF_NULL(lookup_result);
|
||||
int64_t ts = lookup_customer_->NewRequest(::ps::kServerGroup);
|
||||
const auto &callback = [this, ts, keys, lookup_ids, lookup_result, cb]() mutable {
|
||||
mutex_.lock();
|
||||
auto &kvs = lookup_results_[ts];
|
||||
mutex_.unlock();
|
||||
|
||||
if (lookup_ids.empty()) {
|
||||
MS_LOG(EXCEPTION) << "Lookup id is empty.";
|
||||
}
|
||||
int64_t single_id_len = SizeToLong(lookup_result->size() / lookup_ids.size());
|
||||
std::unordered_map<Key, std::shared_ptr<std::pair<T *, int64_t>>> id_addr_map;
|
||||
for (const auto &s : kvs) {
|
||||
int64_t offset = 0;
|
||||
for (size_t i = 0; i < s.keys.size(); i++) {
|
||||
const Key &key = s.keys[i];
|
||||
T *addr = s.vals.data() + offset;
|
||||
offset += single_id_len;
|
||||
id_addr_map[key] = std::make_shared<std::pair<T *, int64_t>>(std::make_pair(addr, single_id_len));
|
||||
MS_EXCEPTION_IF_NULL(id_addr_map[key]);
|
||||
}
|
||||
}
|
||||
|
||||
T *result_addr = lookup_result->data();
|
||||
MS_EXCEPTION_IF_NULL(result_addr);
|
||||
int64_t offset = 0;
|
||||
size_t dst_size = 0;
|
||||
size_t src_size = 0;
|
||||
void *dst_data = nullptr;
|
||||
void *src_data = nullptr;
|
||||
for (size_t i = 0; i < lookup_ids.size(); i++) {
|
||||
if (id_addr_map.count(lookup_ids[i]) == 0) {
|
||||
offset += single_id_len;
|
||||
continue;
|
||||
}
|
||||
auto &pair = id_addr_map[static_cast<Key>(lookup_ids[i])];
|
||||
int64_t size = single_id_len * sizeof(T);
|
||||
dst_size = size;
|
||||
src_size = size;
|
||||
dst_data = result_addr + offset;
|
||||
src_data = pair->first;
|
||||
MS_EXCEPTION_IF_NULL(dst_data);
|
||||
MS_EXCEPTION_IF_NULL(src_data);
|
||||
auto ret = memcpy_s(dst_data, dst_size, src_data, src_size);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
|
||||
return;
|
||||
}
|
||||
offset += single_id_len;
|
||||
}
|
||||
|
||||
mutex_.lock();
|
||||
lookup_results_.erase(ts);
|
||||
mutex_.unlock();
|
||||
if (cb) cb();
|
||||
};
|
||||
lookup_callbacks_[ts] = callback;
|
||||
return ts;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
int64_t WorkerProxy<T>::AddGeneralRspCB(const ::ps::SArray<::ps::Key> &keys, ::ps::SArray<T> *vals,
|
||||
::ps::SArray<int> *lens, int64_t cmd, const Callback &cb) {
|
||||
int64_t ts = general_customer_->NewRequest(::ps::kServerGroup);
|
||||
const auto &callback = [this, ts, keys, vals, lens, cb]() mutable {
|
||||
mutex_.lock();
|
||||
std::map<int64_t, ::ps::KVPairs<T>> server_kvs = gathered_response_[ts];
|
||||
mutex_.unlock();
|
||||
|
||||
vals->clear();
|
||||
for (auto kvs : server_kvs) {
|
||||
for (auto val : kvs.second.vals) {
|
||||
vals->push_back(val);
|
||||
}
|
||||
if (lens) {
|
||||
for (auto len : kvs.second.lens) {
|
||||
lens->push_back(len);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
mutex_.lock();
|
||||
gathered_response_.erase(ts);
|
||||
mutex_.unlock();
|
||||
if (cb) {
|
||||
cb();
|
||||
}
|
||||
};
|
||||
general_callbacks_[ts] = callback;
|
||||
return ts;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void WorkerProxy<T>::LookupIdSlicer(int64_t timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
|
||||
std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced,
|
||||
const std::map<int64_t, int64_t> &attrs) {
|
||||
MS_EXCEPTION_IF_NULL(sliced);
|
||||
int32_t *lookup_ids = send.lens.data();
|
||||
size_t id_size = send.lens.size();
|
||||
|
||||
const Key &key = send.keys[0];
|
||||
const std::vector<::ps::Range> &ranges = *(embedding_table_ranges_[key]);
|
||||
sliced->resize(ranges.size());
|
||||
|
||||
for (size_t i = 0; i < ranges.size(); i++) {
|
||||
const ::ps::Range &range = ranges[i];
|
||||
const auto &begin = range.begin();
|
||||
const auto &end = range.end();
|
||||
std::unordered_set<int64_t> unique_ids;
|
||||
auto &kvs = sliced->at(i).second;
|
||||
|
||||
kvs.keys.push_back(key);
|
||||
kvs.vals.push_back(0.0f);
|
||||
|
||||
for (size_t j = 0; j < id_size; j++) {
|
||||
auto lookup_id = static_cast<uint64_t>(lookup_ids[j]);
|
||||
// If lookup_id is out of range, like negative number, unique_ids will not contain it.
|
||||
// Servers always get lookup_ids in its embedding table range.
|
||||
if (lookup_id >= begin && lookup_id <= end) {
|
||||
unique_ids.insert(lookup_id);
|
||||
}
|
||||
}
|
||||
for (const auto &lookup_id : unique_ids) {
|
||||
kvs.keys.push_back(lookup_id);
|
||||
kvs.vals.push_back(0.0f);
|
||||
}
|
||||
|
||||
if (kvs.keys.size() <= 1) {
|
||||
sliced->at(i).first = false;
|
||||
} else {
|
||||
sliced->at(i).first = true;
|
||||
expected_result_count_[timestamp] += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void WorkerProxy<T>::SparseSlicer(int64_t timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
|
||||
std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced,
|
||||
const std::map<int64_t, int64_t> &attrs) {
|
||||
MS_EXCEPTION_IF_NULL(sliced);
|
||||
// Init variables
|
||||
T *data = send.vals.data();
|
||||
|
||||
if (attrs.count(0) == 0 || attrs.count(1) == 0 || attrs.count(2) == 0 || attrs.count(3) == 0) {
|
||||
MS_LOG(EXCEPTION) << "Invalid attrs keys";
|
||||
}
|
||||
auto iter = attrs.find(0);
|
||||
size_t grad_index = static_cast<size_t>(iter->second);
|
||||
iter = attrs.find(1);
|
||||
size_t indice_index = static_cast<size_t>(iter->second);
|
||||
iter = attrs.find(2);
|
||||
size_t first_dim_size = static_cast<size_t>(iter->second);
|
||||
iter = attrs.find(3);
|
||||
size_t outer_dim_size = static_cast<size_t>(iter->second);
|
||||
|
||||
int grad_size = send.lens[grad_index];
|
||||
int indice_size = send.lens[indice_index];
|
||||
int segment_size = grad_size / indice_size;
|
||||
|
||||
int64_t grad_offset = 0;
|
||||
int64_t indice_offset = 0;
|
||||
for (size_t i = 0; i < grad_index; i++) {
|
||||
grad_offset += send.lens[i];
|
||||
}
|
||||
for (size_t j = 0; j < indice_index; j++) {
|
||||
indice_offset += send.lens[j];
|
||||
}
|
||||
|
||||
T *grad_data = data + grad_offset;
|
||||
int *indice_data = reinterpret_cast<int *>(data) + indice_offset;
|
||||
|
||||
// Build the mappings of indice to gradient
|
||||
std::vector<std::pair<int, T *>> indice_to_grads;
|
||||
for (int i = 0; i < indice_size; i++) {
|
||||
int indice = indice_data[i];
|
||||
T *grad = grad_data + i * segment_size;
|
||||
indice_to_grads.push_back(std::make_pair(indice, grad));
|
||||
}
|
||||
|
||||
const Key &key = send.keys[0];
|
||||
const std::vector<::ps::Range> &ranges = *(embedding_table_ranges_[key]);
|
||||
sliced->resize(ranges.size());
|
||||
|
||||
// Construct reduced sparse data for each server
|
||||
for (size_t i = 0; i < ranges.size(); i++) {
|
||||
const ::ps::Range &range = ranges[i];
|
||||
const auto &begin = range.begin();
|
||||
const auto &end = range.end();
|
||||
auto &kvs = sliced->at(i).second;
|
||||
kvs.keys = send.keys;
|
||||
kvs.lens = send.lens;
|
||||
|
||||
// Prepare the sparse gradient and indice
|
||||
std::vector<int> indice_ids;
|
||||
std::unordered_set<int> distinct_ids;
|
||||
for (int j = 0; j < indice_size; j++) {
|
||||
size_t indice = static_cast<size_t>(indice_data[j]);
|
||||
if (indice >= begin && indice <= end) {
|
||||
indice_ids.push_back(indice);
|
||||
distinct_ids.insert(indice);
|
||||
}
|
||||
}
|
||||
size_t indices_size = indice_ids.size();
|
||||
if (indices_size > 0) {
|
||||
int slice_segment_size = indices_size * segment_size;
|
||||
std::vector<T> src_grad_data(slice_segment_size);
|
||||
std::vector<int> src_indice_data(indices_size);
|
||||
PrepareSparseGradient(begin, end, distinct_ids, indice_to_grads, indice_data, segment_size, src_grad_data.data(),
|
||||
src_indice_data.data());
|
||||
|
||||
// Reduce the sparse gradient and indice
|
||||
std::vector<T> new_grad(slice_segment_size);
|
||||
std::vector<int> new_indices(indices_size);
|
||||
mindspore::kernel::SparseGradient<int> unique_sparse_grad({new_grad.data(), new_indices.data(), indices_size});
|
||||
Util::ReduceSparseGradient(src_grad_data.data(), src_indice_data.data(), indices_size, segment_size,
|
||||
first_dim_size, outer_dim_size, &unique_sparse_grad);
|
||||
|
||||
// Update the length of reduce sparse gradient and indice
|
||||
::ps::SArray<int> reduced_lens;
|
||||
reduced_lens.CopyFrom(kvs.lens);
|
||||
reduced_lens[grad_index] = unique_sparse_grad.indices_size_ * segment_size;
|
||||
reduced_lens[indice_index] = unique_sparse_grad.indices_size_;
|
||||
|
||||
// Build the sparse value to be sent
|
||||
size_t total_size = std::accumulate(reduced_lens.begin(), reduced_lens.end(), 0, std::plus<int>());
|
||||
::ps::SArray<T> reduced_data(total_size, 0);
|
||||
BuildSparseValue(reduced_lens, grad_index, indice_index, data, unique_sparse_grad.value_,
|
||||
unique_sparse_grad.indices_, &reduced_data);
|
||||
|
||||
kvs.lens = reduced_lens;
|
||||
kvs.vals = reduced_data;
|
||||
}
|
||||
|
||||
if (indices_size <= 0) {
|
||||
::ps::SArray<T> no_keys;
|
||||
::ps::SArray<T> no_vals;
|
||||
::ps::SArray<T> no_lens;
|
||||
no_keys.push_back(key);
|
||||
no_vals.push_back(-100);
|
||||
kvs.vals = no_vals;
|
||||
kvs.lens = no_lens;
|
||||
}
|
||||
sliced->at(i).first = true;
|
||||
expected_result_count_[timestamp] += 1;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void WorkerProxy<T>::PrepareSparseGradient(const size_t begin, const size_t end,
|
||||
const std::unordered_set<int> &distinct_ids,
|
||||
const std::vector<std::pair<int, T *>> &indice_to_grads,
|
||||
const int *all_indice, const size_t segment_size, T *gradient,
|
||||
int *indices) {
|
||||
MS_EXCEPTION_IF_NULL(all_indice);
|
||||
MS_EXCEPTION_IF_NULL(gradient);
|
||||
MS_EXCEPTION_IF_NULL(indices);
|
||||
int64_t offset = 0;
|
||||
int64_t index = 0;
|
||||
size_t segment_data_size = segment_size * sizeof(T);
|
||||
size_t dst_size;
|
||||
size_t src_size;
|
||||
void *dst_data = nullptr;
|
||||
void *src_data = nullptr;
|
||||
for (auto &pair : indice_to_grads) {
|
||||
if (distinct_ids.count(pair.first) == 0) {
|
||||
continue;
|
||||
}
|
||||
indices[index++] = pair.first;
|
||||
|
||||
dst_size = segment_data_size;
|
||||
src_size = segment_data_size;
|
||||
dst_data = gradient + offset;
|
||||
src_data = pair.second;
|
||||
MS_EXCEPTION_IF_NULL(dst_data);
|
||||
MS_EXCEPTION_IF_NULL(src_data);
|
||||
auto ret = memcpy_s(gradient + offset, dst_size, pair.second, src_size);
|
||||
if (ret != 0) {
|
||||
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
|
||||
return;
|
||||
}
|
||||
offset += segment_size;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void WorkerProxy<T>::BuildSparseValue(const ::ps::SArray<int> &lengths, const size_t grad_index,
|
||||
const size_t indice_index, const T *original_data, const T *grads, int *indices,
|
||||
::ps::SArray<T> *reduced_data) {
|
||||
MS_EXCEPTION_IF_NULL(original_data);
|
||||
MS_EXCEPTION_IF_NULL(grads);
|
||||
MS_EXCEPTION_IF_NULL(indices);
|
||||
MS_EXCEPTION_IF_NULL(reduced_data);
|
||||
int64_t offset = 0;
|
||||
size_t dst_size = 0;
|
||||
size_t src_size = 0;
|
||||
void *dst_data = nullptr;
|
||||
void *src_data = nullptr;
|
||||
for (size_t i = 0; i < lengths.size(); i++) {
|
||||
if (i != grad_index && i != indice_index) {
|
||||
int data_size = lengths[i] * sizeof(T);
|
||||
dst_size = data_size;
|
||||
src_size = data_size;
|
||||
dst_data = reduced_data->data() + offset;
|
||||
src_data = const_cast<T *>(original_data) + offset;
|
||||
MS_EXCEPTION_IF_NULL(dst_data);
|
||||
MS_EXCEPTION_IF_NULL(src_data);
|
||||
auto ret = memcpy_s(dst_data, dst_size, src_data, src_size);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
|
||||
return;
|
||||
}
|
||||
}
|
||||
offset += lengths[i];
|
||||
}
|
||||
|
||||
// Fill the reduced gradient
|
||||
int64_t grad_offset = 0;
|
||||
for (size_t i = 0; i < grad_index; i++) {
|
||||
grad_offset += lengths[i];
|
||||
}
|
||||
int64_t data_size = lengths[grad_index] * sizeof(T);
|
||||
dst_size = data_size;
|
||||
src_size = data_size;
|
||||
dst_data = reduced_data->data() + grad_offset;
|
||||
src_data = const_cast<T *>(grads);
|
||||
MS_EXCEPTION_IF_NULL(dst_data);
|
||||
MS_EXCEPTION_IF_NULL(src_data);
|
||||
auto ret = memcpy_s(dst_data, dst_size, src_data, src_size);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
|
||||
return;
|
||||
}
|
||||
|
||||
// Fill the reduced indice
|
||||
int64_t indice_offset = grad_offset + lengths[grad_index];
|
||||
data_size = lengths[indice_index] * sizeof(T);
|
||||
T *indice_data = reduced_data->data() + indice_offset;
|
||||
dst_size = data_size;
|
||||
src_size = data_size;
|
||||
dst_data = indice_data;
|
||||
src_data = indices;
|
||||
MS_EXCEPTION_IF_NULL(dst_data);
|
||||
MS_EXCEPTION_IF_NULL(src_data);
|
||||
ret = memcpy_s(dst_data, dst_size, src_data, src_size);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void WorkerProxy<T>::BroadcastSlicer(int64_t timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
|
||||
std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced,
|
||||
const std::map<int64_t, int64_t> &attr) {
|
||||
MS_EXCEPTION_IF_NULL(sliced);
|
||||
sliced->resize(server_num_);
|
||||
for (int64_t i = 0; i < server_num_; i++) {
|
||||
sliced->at(i).first = true;
|
||||
sliced->at(i).second = send;
|
||||
expected_result_count_[timestamp] += 1;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void WorkerProxy<T>::RoundRobinSlicer(int64_t timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
|
||||
std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced,
|
||||
const std::map<int64_t, int64_t> &attr) {
|
||||
MS_EXCEPTION_IF_NULL(sliced);
|
||||
sliced->resize(server_num_);
|
||||
auto keys = send.keys;
|
||||
auto vals = send.vals;
|
||||
auto lens = send.lens;
|
||||
|
||||
int64_t server_id, len;
|
||||
::ps::Key param_key;
|
||||
for (size_t i = 0; i < keys.size(); i++) {
|
||||
param_key = keys[i];
|
||||
server_id = key_to_server_id_[param_key];
|
||||
if (!sliced->at(server_id).first) {
|
||||
sliced->at(server_id).first = true;
|
||||
expected_result_count_[timestamp] += 1;
|
||||
}
|
||||
|
||||
::ps::KVPairs<T> &server_kv_pairs = sliced->at(server_id).second;
|
||||
server_kv_pairs.keys.push_back(param_key);
|
||||
if (vals.empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
len = lens[i];
|
||||
int64_t offset = std::accumulate(lens.begin(), lens.begin() + i, 0);
|
||||
auto val_begin = vals.begin() + offset;
|
||||
auto val_end = val_begin + len;
|
||||
|
||||
for (auto iter = val_begin; iter != val_end; iter++) {
|
||||
server_kv_pairs.vals.push_back(*iter);
|
||||
}
|
||||
server_kv_pairs.lens.push_back(len);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void WorkerProxy<T>::WorkerInitEmbeddingSlicer(int64_t timestamp, const ::ps::KVPairs<T> &send,
|
||||
const std::vector<::ps::Range> &,
|
||||
std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced,
|
||||
const std::map<int64_t, int64_t> &attrs) {
|
||||
MS_EXCEPTION_IF_NULL(sliced);
|
||||
sliced->resize(server_num_);
|
||||
auto keys = send.keys;
|
||||
auto vals = send.vals;
|
||||
auto lens = send.lens;
|
||||
|
||||
size_t col_cnt = lens[0] / embedding_row_cnt_[keys[0]];
|
||||
const std::vector<::ps::Range> &ranges = *(embedding_table_ranges_[keys[0]]);
|
||||
for (size_t i = 0; i < ranges.size(); i++) {
|
||||
size_t offset_begin = ranges[i].begin() * col_cnt;
|
||||
size_t offset_end = (ranges[i].end() + 1) * col_cnt;
|
||||
::ps::KVPairs<T> kvs;
|
||||
kvs.keys = keys;
|
||||
kvs.vals = vals.segment(offset_begin, offset_end);
|
||||
kvs.lens.push_back(offset_end - offset_begin);
|
||||
sliced->at(i).first = true;
|
||||
sliced->at(i).second = kvs;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void WorkerProxy<T>::UpdateEmbeddingSlicer(int timestamp, const ::ps::KVPairs<T> &send,
|
||||
const std::vector<::ps::Range> &,
|
||||
std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced,
|
||||
const std::map<int64_t, int64_t> &attrs) {
|
||||
MS_EXCEPTION_IF_NULL(sliced);
|
||||
T *embedding_vals = send.vals.data();
|
||||
int *lookup_ids = send.lens.data();
|
||||
size_t val_size = send.vals.size();
|
||||
size_t id_size = send.lens.size();
|
||||
size_t embedding_dim = val_size / id_size;
|
||||
|
||||
const Key &key = send.keys[0];
|
||||
const std::vector<::ps::Range> &ranges = *(embedding_table_ranges_[key]);
|
||||
sliced->resize(ranges.size());
|
||||
|
||||
for (size_t i = 0; i < ranges.size(); i++) {
|
||||
const ::ps::Range &range = ranges[i];
|
||||
const auto &begin = range.begin();
|
||||
const auto &end = range.end();
|
||||
auto &kvs = sliced->at(i).second;
|
||||
kvs.keys.push_back(key);
|
||||
for (size_t j = 0; j < id_size; j++) {
|
||||
auto lookup_id = static_cast<uint64_t>(lookup_ids[j]);
|
||||
if (lookup_id >= begin && lookup_id <= end) {
|
||||
kvs.keys.push_back(lookup_id);
|
||||
for (size_t k = 0; k < embedding_dim; k++) {
|
||||
kvs.vals.push_back(embedding_vals[j * embedding_dim + k]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (kvs.keys.size() <= 1) {
|
||||
sliced->at(i).first = false;
|
||||
} else {
|
||||
sliced->at(i).first = true;
|
||||
expected_result_count_[timestamp] += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void WorkerProxy<T>::ProcessLookupResult(const ::ps::Message &msg) {
|
||||
int64_t ts = msg.meta.timestamp;
|
||||
if (msg.meta.pull) {
|
||||
CHECK_GE(msg.data.size(), (size_t)2);
|
||||
::ps::KVPairs<T> kvs;
|
||||
kvs.keys = msg.data[0];
|
||||
kvs.vals = msg.data[1];
|
||||
if (msg.data.size() > (size_t)2) {
|
||||
kvs.lens = msg.data[2];
|
||||
}
|
||||
mutex_.lock();
|
||||
lookup_results_[ts].push_back(kvs);
|
||||
mutex_.unlock();
|
||||
}
|
||||
if (lookup_customer_->NumResponse(ts) + 1 == server_num_) {
|
||||
const auto &cb = lookup_callbacks_[ts];
|
||||
cb();
|
||||
lookup_callbacks_.erase(ts);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void WorkerProxy<T>::ProcessResponse(const ::ps::Message &msg) {
|
||||
int64_t ts = msg.meta.timestamp;
|
||||
|
||||
if (msg.meta.pull) {
|
||||
CHECK_GE(msg.data.size(), (size_t)2);
|
||||
::ps::KVPairs<T> kvs;
|
||||
kvs.keys = msg.data[0];
|
||||
kvs.vals = msg.data[1];
|
||||
if (msg.data.size() > (size_t)2) {
|
||||
kvs.lens = msg.data[2];
|
||||
}
|
||||
mutex_.lock();
|
||||
int rsp_server_rank = ::ps::Postoffice::Get()->IDtoRank(msg.meta.sender);
|
||||
gathered_response_[ts][rsp_server_rank] = kvs;
|
||||
mutex_.unlock();
|
||||
if (general_customer_->NumResponse(ts) + 1 == server_num_) {
|
||||
const auto &cb = general_callbacks_[ts];
|
||||
cb();
|
||||
general_callbacks_.erase(ts);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void WorkerProxy<T>::Send(::ps::Customer *customer, int64_t timestamp, bool push, bool pull, int64_t cmd,
|
||||
const ::ps::KVPairs<T> &kvs, const Slicer &slicer, std::map<int64_t, int64_t> attrs) {
|
||||
MS_EXCEPTION_IF_NULL(customer);
|
||||
SlicedKVs sliced;
|
||||
slicer(timestamp, kvs, ::ps::Postoffice::Get()->GetServerKeyRanges(), &sliced, attrs);
|
||||
|
||||
for (size_t i = 0; i < sliced.size(); i++) {
|
||||
const auto &s = sliced[i];
|
||||
if (!s.first) continue;
|
||||
::ps::Message msg;
|
||||
msg.meta.app_id = customer->app_id();
|
||||
msg.meta.customer_id = customer->customer_id();
|
||||
msg.meta.request = true;
|
||||
msg.meta.push = push;
|
||||
msg.meta.pull = pull;
|
||||
msg.meta.head = cmd;
|
||||
msg.meta.timestamp = timestamp;
|
||||
msg.meta.recver = ::ps::Postoffice::Get()->ServerRankToID(i);
|
||||
msg.meta.priority = kvs.priority;
|
||||
const auto &kvs = s.second;
|
||||
if (kvs.keys.size()) {
|
||||
msg.AddData(kvs.keys);
|
||||
msg.AddData(kvs.vals);
|
||||
if (kvs.lens.size()) {
|
||||
msg.AddData(kvs.lens);
|
||||
}
|
||||
}
|
||||
::ps::Postoffice::Get()->van()->Send(msg);
|
||||
}
|
||||
}
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PS_WORKER_PROXY_H_
|
|
@ -24,7 +24,7 @@ namespace mindspore {
|
|||
namespace device {
|
||||
void KernelRuntimeManager::ClearRuntimeResource() {
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
if (ps::Util::IsRoleOfWorker() && ps::PsDataPrefetch::GetInstance().cache_enable()) {
|
||||
if (ps::PSContext::instance()->is_worker() && ps::PsDataPrefetch::GetInstance().cache_enable()) {
|
||||
ps::ps_cache_instance.SyncEmbeddingTable();
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -78,7 +78,6 @@ export DEVICE_NUM=8
|
|||
export RANK_SIZE=8
|
||||
export RANK_TABLE_FILE=$PATH1
|
||||
|
||||
export MS_COMM_TYPE=zmq
|
||||
export MS_SCHED_NUM=1
|
||||
export MS_WORKER_NUM=$RANK_SIZE
|
||||
export MS_SERVER_NUM=8
|
||||
|
|
|
@ -70,7 +70,6 @@ fi
|
|||
export DEVICE_NUM=8
|
||||
export RANK_SIZE=8
|
||||
|
||||
export MS_COMM_TYPE=zmq
|
||||
export MS_SCHED_NUM=1
|
||||
export MS_WORKER_NUM=8
|
||||
export MS_SERVER_NUM=8
|
||||
|
|
|
@ -27,7 +27,6 @@ export EPOCH_SIZE=$2
|
|||
export DEVICE_TARGET=$3
|
||||
export DATASET=$4
|
||||
|
||||
export MS_COMM_TYPE=zmq
|
||||
export MS_SCHED_NUM=1
|
||||
export MS_WORKER_NUM=$RANK_SIZE
|
||||
export LOCAL_WORKER_NUM=$5
|
||||
|
|
|
@ -25,7 +25,6 @@ export RANK_SIZE=$1
|
|||
export EPOCH_SIZE=$2
|
||||
export DEVICE_TARGET=$3
|
||||
export DATASET=$4
|
||||
export MS_COMM_TYPE=zmq
|
||||
export MS_SCHED_NUM=1
|
||||
export MS_WORKER_NUM=$RANK_SIZE
|
||||
export MS_SERVER_NUM=$5
|
||||
|
|
|
@ -23,7 +23,6 @@ self_path=$(dirname "${script_self}")
|
|||
export EPOCH_SIZE=$1
|
||||
export DEVICE_TARGET=$2
|
||||
export DATASET=$3
|
||||
export MS_COMM_TYPE=zmq
|
||||
export MS_SCHED_NUM=1
|
||||
export MS_WORKER_NUM=1
|
||||
export MS_SERVER_NUM=$4
|
||||
|
|
|
@ -15,8 +15,7 @@
|
|||
# ============================================================================
|
||||
|
||||
execute_path=$(pwd)
|
||||
self_path=$(dirname "${script_self}")
|
||||
export MS_COMM_TYPE=zmq
|
||||
self_path=$(dirname $0)
|
||||
export MS_SCHED_NUM=1
|
||||
DEVICE_TARGET=$1
|
||||
export MS_WORKER_NUM=$2
|
||||
|
|
|
@ -15,8 +15,7 @@
|
|||
# ============================================================================
|
||||
|
||||
execute_path=$(pwd)
|
||||
self_path=$(dirname "${script_self}")
|
||||
export MS_COMM_TYPE=zmq
|
||||
self_path=$(dirname $0)
|
||||
export MS_SCHED_NUM=1
|
||||
DEVICE_TARGET=$1
|
||||
DATASET_PATH=$2
|
||||
|
|
|
@ -15,8 +15,7 @@
|
|||
# ============================================================================
|
||||
|
||||
execute_path=$(pwd)
|
||||
self_path=$(dirname "${script_self}")
|
||||
export MS_COMM_TYPE=zmq
|
||||
self_path=$(dirname $0)
|
||||
export MS_SCHED_NUM=1
|
||||
DEVICE_TARGET=$1
|
||||
export MS_WORKER_NUM=$2
|
||||
|
|
|
@ -15,8 +15,7 @@
|
|||
# ============================================================================
|
||||
|
||||
execute_path=$(pwd)
|
||||
self_path=$(dirname "${script_self}")
|
||||
export MS_COMM_TYPE=zmq
|
||||
self_path=$(dirname $0)
|
||||
export MS_SCHED_NUM=1
|
||||
DEVICE_TARGET=$1
|
||||
DATASET_PATH=$2
|
||||
|
|
|
@ -150,6 +150,8 @@ list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/internal/parame
|
|||
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/scheduler.cc")
|
||||
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/optimizer_info.cc")
|
||||
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/optimizer_info_builder.cc")
|
||||
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/worker.cc")
|
||||
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/parameter_server.cc")
|
||||
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/ps_cache/gpu/gpu_ps_cache.cc")
|
||||
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/ps_cache/ascend/ascend_ps_cache.cc")
|
||||
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc")
|
||||
|
|
|
@ -1,255 +0,0 @@
|
|||
diff -Npur ps-lite-master/include/dmlc/base.h ps-lite-master-new/include/dmlc/base.h
|
||||
--- ps-lite-master/include/dmlc/base.h 2020-02-29 13:59:55.000000000 +0800
|
||||
+++ ps-lite-master-new/include/dmlc/base.h 2020-07-01 11:56:50.444833389 +0800
|
||||
@@ -8,7 +8,7 @@
|
||||
|
||||
/*! \brief whether use glog for logging */
|
||||
#ifndef DMLC_USE_GLOG
|
||||
-#define DMLC_USE_GLOG 0
|
||||
+#define DMLC_USE_GLOG 1
|
||||
#endif
|
||||
|
||||
/*!
|
||||
diff -Npur ps-lite-master/include/dmlc/logging.h ps-lite-master-new/include/dmlc/logging.h
|
||||
--- ps-lite-master/include/dmlc/logging.h 2020-02-29 13:59:55.000000000 +0800
|
||||
+++ ps-lite-master-new/include/dmlc/logging.h 2020-07-08 21:35:33.334584767 +0800
|
||||
@@ -52,7 +52,7 @@ struct Error : public std::runtime_error
|
||||
|
||||
namespace dmlc {
|
||||
inline void InitLogging(const char* argv0) {
|
||||
- google::InitGoogleLogging(argv0);
|
||||
+ //google::InitGoogleLogging(argv0);
|
||||
}
|
||||
} // namespace dmlc
|
||||
|
||||
diff -Npur ps-lite-master/make/deps.mk ps-lite-master-new/make/deps.mk
|
||||
--- ps-lite-master/make/deps.mk 2020-02-29 13:59:55.000000000 +0800
|
||||
+++ ps-lite-master-new/make/deps.mk 2020-06-17 10:35:46.253837426 +0800
|
||||
@@ -1,69 +1,7 @@
|
||||
# Install dependencies
|
||||
-
|
||||
-URL1=https://raw.githubusercontent.com/mli/deps/master/build
|
||||
-URL2=https://github.com/google/protobuf/releases/download/v3.5.1
|
||||
-ifndef WGET
|
||||
-WGET = wget
|
||||
-endif
|
||||
-
|
||||
-# protobuf
|
||||
-PROTOBUF = ${DEPS_PATH}/include/google/protobuf/message.h
|
||||
-${PROTOBUF}:
|
||||
- $(eval FILE=protobuf-cpp-3.5.1.tar.gz)
|
||||
- $(eval DIR=protobuf-3.5.1)
|
||||
- rm -rf $(FILE) $(DIR)
|
||||
- $(WGET) $(URL2)/$(FILE) && tar --no-same-owner -zxf $(FILE)
|
||||
- cd $(DIR) && export CFLAGS=-fPIC && export CXXFLAGS=-fPIC && ./configure -prefix=$(DEPS_PATH) && $(MAKE) && $(MAKE) install
|
||||
- rm -rf $(FILE) $(DIR)
|
||||
-
|
||||
# zmq
|
||||
-ZMQ = ${DEPS_PATH}/include/zmq.h
|
||||
+ZMQ = $(MS_ZMQ_INSTALL_PATH)/lib/libzmq.a
|
||||
|
||||
${ZMQ}:
|
||||
- $(eval FILE=zeromq-4.1.4.tar.gz)
|
||||
- $(eval DIR=zeromq-4.1.4)
|
||||
- rm -rf $(FILE) $(DIR)
|
||||
- $(WGET) $(URL1)/$(FILE) && tar --no-same-owner -zxf $(FILE)
|
||||
- cd $(DIR) && export CFLAGS=-fPIC && export CXXFLAGS=-fPIC && ./configure -prefix=$(DEPS_PATH) --with-libsodium=no --with-libgssapi_krb5=no && $(MAKE) && $(MAKE) install
|
||||
- rm -rf $(FILE) $(DIR)
|
||||
-
|
||||
-# lz4
|
||||
-LZ4 = ${DEPS_PATH}/include/lz4.h
|
||||
-${LZ4}:
|
||||
- $(eval FILE=lz4-r129.tar.gz)
|
||||
- $(eval DIR=lz4-r129)
|
||||
- rm -rf $(FILE) $(DIR)
|
||||
- wget $(URL1)/$(FILE) && tar --no-same-owner -zxf $(FILE)
|
||||
- cd $(DIR) && $(MAKE) && PREFIX=$(DEPS_PATH) $(MAKE) install
|
||||
- rm -rf $(FILE) $(DIR)
|
||||
-
|
||||
-# cityhash
|
||||
-CITYHASH = ${DEPS_PATH}/include/city.h
|
||||
-${CITYHASH}:
|
||||
- $(eval FILE=cityhash-1.1.1.tar.gz)
|
||||
- $(eval DIR=cityhash-1.1.1)
|
||||
- rm -rf $(FILE) $(DIR)
|
||||
- wget $(URL1)/$(FILE) && tar --no-same-owner -zxf $(FILE)
|
||||
- cd $(DIR) && ./configure -prefix=$(DEPS_PATH) --enable-sse4.2 && $(MAKE) CXXFLAGS="-g -O3 -msse4.2" && $(MAKE) install
|
||||
- rm -rf $(FILE) $(DIR)
|
||||
-
|
||||
-
|
||||
-# # gflags
|
||||
-# ${DEPS_PATH}/include/google/gflags.h:
|
||||
-# $(eval FILE=gflags-2.0-no-svn-files.tar.gz)
|
||||
-# $(eval DIR=gflags-2.0)
|
||||
-# rm -rf $(FILE) $(DIR)
|
||||
-# wget $(URL)/$(FILE) && tar -zxf $(FILE)
|
||||
-# cd $(DIR) && ./configure -prefix=$(DEPS_PATH) && $(MAKE) && $(MAKE) install
|
||||
-# rm -rf $(FILE) $(DIR)
|
||||
-# gflags: | ${DEPS_PATH}/include/google/gflags.h
|
||||
+ cd $(MS_ZMQ_DIR) && export CFLAGS="-fPIC -D_GLIBCXX_USE_CXX11_ABI=0" && export CXXFLAGS=-fPIC && ./configure -prefix=$(MS_ZMQ_INSTALL_PATH) --with-libsodium=no --with-libgssapi_krb5=no && $(MAKE) && $(MAKE) install
|
||||
|
||||
-# # glog
|
||||
-# ${DEPS_PATH}/include/glog/logging.h: | ${DEPS_PATH}/include/google/gflags.h
|
||||
-# $(eval FILE=v0.3.4.tar.gz)
|
||||
-# $(eval DIR=glog-0.3.4)
|
||||
-# rm -rf $(FILE) $(DIR)
|
||||
-# wget https://github.com/google/glog/archive/$(FILE) && tar -zxf $(FILE)
|
||||
-# cd $(DIR) && ./configure -prefix=$(DEPS_PATH) --with-gflags=$(DEPS_PATH) && $(MAKE) && $(MAKE) install
|
||||
-# rm -rf $(FILE) $(DIR)
|
||||
-# glog: | ${DEPS_PATH}/include/glog/logging.h
|
||||
diff -Npur ps-lite-master/make/ps.mk ps-lite-master-new/make/ps.mk
|
||||
--- ps-lite-master/make/ps.mk 2020-02-29 13:59:55.000000000 +0800
|
||||
+++ ps-lite-master-new/make/ps.mk 2020-06-05 09:28:35.337740291 +0800
|
||||
@@ -9,5 +9,5 @@ ifeq ($(USE_KEY32), 1)
|
||||
ADD_CFLAGS += -DUSE_KEY32=1
|
||||
endif
|
||||
|
||||
-PS_LDFLAGS_SO = -L$(DEPS_PATH)/lib -lprotobuf-lite -lzmq
|
||||
-PS_LDFLAGS_A = $(addprefix $(DEPS_PATH)/lib/, libprotobuf-lite.a libzmq.a)
|
||||
+PS_LDFLAGS_SO = -L$(MS_ZMQ_INSTALL_PATH)/lib -lzmq -L$(MS_PROTO_LIB_DIR) -lprotobuf-lite
|
||||
+PS_LDFLAGS_A = $(addprefix $(MS_ZMQ_INSTALL_PATH)/lib -L$(MS_PROTO_LIB_DIR), libprotobuf-lite.a libzmq.a)
|
||||
diff -Npur ps-lite-master/Makefile ps-lite-master-new/Makefile
|
||||
--- ps-lite-master/Makefile 2020-02-29 13:59:55.000000000 +0800
|
||||
+++ ps-lite-master-new/Makefile 2020-06-17 11:09:20.240322660 +0800
|
||||
@@ -12,13 +12,24 @@ ifndef DEPS_PATH
|
||||
DEPS_PATH = $(shell pwd)/deps
|
||||
endif
|
||||
|
||||
+MS_PROTO_DIR = @protobuf_DIRPATH@
|
||||
+MS_GLOG_DIR = @glog_DIRPATH@
|
||||
+MS_ZMQ_DIR = @zeromq_DIRPATH@
|
||||
+
|
||||
+MS_PROTO_LIB_DIR = @protobuf_LIBPATH@
|
||||
+MS_GLOG_LIB_DIR = @glog_LIBPATH@
|
||||
+MS_ZMQ_INSTALL_PATH = $(MS_ZMQ_DIR)/zmq_install
|
||||
|
||||
ifndef PROTOC
|
||||
-PROTOC = ${DEPS_PATH}/bin/protoc
|
||||
+PROTOC = $(MS_PROTO_DIR)/bin/protoc
|
||||
endif
|
||||
|
||||
-INCPATH = -I./src -I./include -I$(DEPS_PATH)/include
|
||||
-CFLAGS = -std=c++11 -msse2 -fPIC -O3 -ggdb -Wall -finline-functions $(INCPATH) $(ADD_CFLAGS)
|
||||
+INCPATH = -I./src -I./include -I$(MS_ZMQ_INSTALL_PATH)/include
|
||||
+INCPATH += -I$(MS_PROTO_DIR)/include
|
||||
+INCPATH += -I$(MS_GLOG_DIR)/include
|
||||
+
|
||||
+CXXFLAGS = -D_GLIBCXX_USE_CXX11_ABI=0
|
||||
+CFLAGS = -std=c++11 -fPIC -O3 -ggdb -Wall -finline-functions $(INCPATH) $(ADD_CFLAGS) -D_GLIBCXX_USE_CXX11_ABI=0
|
||||
LIBS = -pthread
|
||||
|
||||
ifdef USE_IBVERBS
|
||||
@@ -30,6 +41,7 @@ ifdef ASAN
|
||||
CFLAGS += -fsanitize=address -fno-omit-frame-pointer -fno-optimize-sibling-calls
|
||||
endif
|
||||
|
||||
+LIBS += -L$(MS_GLOG_LIB_DIR) -lglog
|
||||
|
||||
all: ps test
|
||||
|
||||
@@ -51,9 +63,9 @@ build/libps.a: $(OBJS)
|
||||
build/%.o: src/%.cc ${ZMQ} src/meta.pb.h
|
||||
@mkdir -p $(@D)
|
||||
$(CXX) $(INCPATH) -std=c++11 -MM -MT build/$*.o $< >build/$*.d
|
||||
- $(CXX) $(CFLAGS) $(LIBS) -c $< -o $@
|
||||
+ $(CXX) $(CFLAGS) $(CXXFLAGS) $(LIBS) -c $< -o $@
|
||||
|
||||
-src/%.pb.cc src/%.pb.h : src/%.proto ${PROTOBUF}
|
||||
+src/%.pb.cc src/%.pb.h : src/%.proto
|
||||
$(PROTOC) --cpp_out=./src --proto_path=./src $<
|
||||
|
||||
-include build/*.d
|
||||
diff -Npur ps-lite-master/src/ibverbs_van.h ps-lite-master-new/src/ibverbs_van.h
|
||||
--- ps-lite-master/src/ibverbs_van.h 2020-02-29 13:59:55.000000000 +0800
|
||||
+++ ps-lite-master-new/src/ibverbs_van.h 2020-06-02 20:52:11.076230014 +0800
|
||||
@@ -145,15 +145,15 @@ class SimpleMempool {
|
||||
total_allocated_size += new_mem_size;
|
||||
}
|
||||
|
||||
- CHECK_NE(free_list.end(), it) << "Not enough memory";
|
||||
+ //CHECK_NE(free_list.end(), it) << "Not enough memory";
|
||||
CHECK_GE(it->first, proper_size);
|
||||
|
||||
char *addr = it->second;
|
||||
size_t space_left = it->first - proper_size;
|
||||
|
||||
free_list.erase(it);
|
||||
- CHECK_EQ(used_list.find(addr), used_list.end())
|
||||
- << "Address is already allocated";
|
||||
+ //CHECK_EQ(used_list.find(addr), used_list.end())
|
||||
+ //<< "Address is already allocated";
|
||||
|
||||
used_list.emplace(addr, proper_size);
|
||||
|
||||
@@ -173,8 +173,8 @@ class SimpleMempool {
|
||||
std::lock_guard<std::mutex> lk(mu_);
|
||||
|
||||
auto it = used_list.find(addr);
|
||||
- CHECK_NE(used_list.end(), it)
|
||||
- << "Cannot find info about address: " << (uintptr_t)addr;
|
||||
+ //CHECK_NE(used_list.end(), it)
|
||||
+ //<< "Cannot find info about address: " << (uintptr_t)addr;
|
||||
|
||||
size_t size = it->second;
|
||||
used_list.erase(it);
|
||||
@@ -208,7 +208,7 @@ class SimpleMempool {
|
||||
// Convert the memory address to its associated RDMA memory region
|
||||
inline struct ibv_mr *Addr2MR(char *addr) {
|
||||
auto it = mr_list.lower_bound(addr);
|
||||
- CHECK_NE(it, mr_list.end()) << "cannot find the associated memory region";
|
||||
+ //CHECK_NE(it, mr_list.end()) << "cannot find the associated memory region";
|
||||
return it->second;
|
||||
}
|
||||
};
|
||||
@@ -330,7 +330,7 @@ class AddressPool {
|
||||
CHECK(ptr);
|
||||
uint32_t idx = indices_.front();
|
||||
indices_.pop();
|
||||
- CHECK_EQ(table_[idx], nullptr);
|
||||
+ //CHECK_EQ(table_[idx], nullptr);
|
||||
table_[idx] = ptr;
|
||||
return idx;
|
||||
}
|
||||
@@ -636,7 +636,7 @@ class IBVerbsVan : public Van {
|
||||
PBMeta meta;
|
||||
PackMetaPB(msg.meta, &meta);
|
||||
|
||||
- CHECK_NE(endpoints_.find(remote_id), endpoints_.end());
|
||||
+ //CHECK_NE(endpoints_.find(remote_id), endpoints_.end());
|
||||
Endpoint *endpoint = endpoints_[remote_id].get();
|
||||
MessageBuffer *msg_buf = new MessageBuffer();
|
||||
|
||||
diff -Npur ps-lite-master/src/van.cc ps-lite-master-new/src/van.cc
|
||||
--- ps-lite-master/src/van.cc 2020-02-29 13:59:55.000000000 +0800
|
||||
+++ ps-lite-master-new/src/van.cc 2020-06-02 20:52:43.330405828 +0800
|
||||
@@ -448,6 +448,7 @@ void Van::PackMetaPB(const Meta& meta, P
|
||||
if (meta.timestamp != Meta::kEmpty) pb->set_timestamp(meta.timestamp);
|
||||
if (meta.body.size()) pb->set_body(meta.body);
|
||||
pb->set_push(meta.push);
|
||||
+ pb->set_pull(meta.pull);
|
||||
pb->set_request(meta.request);
|
||||
pb->set_simple_app(meta.simple_app);
|
||||
pb->set_priority(meta.priority);
|
||||
diff -Npur ps-lite-master/tests/test.mk ps-lite-master-new/tests/test.mk
|
||||
--- ps-lite-master/tests/test.mk 2020-02-29 13:59:55.000000000 +0800
|
||||
+++ ps-lite-master-new/tests/test.mk 2020-06-16 19:15:06.025087897 +0800
|
||||
@@ -1,10 +1,10 @@
|
||||
-TEST_SRC = $(wildcard tests/test_*.cc)
|
||||
-TEST = $(patsubst tests/test_%.cc, tests/test_%, $(TEST_SRC))
|
||||
+#TEST_SRC = $(wildcard tests/test_*.cc)
|
||||
+#TEST = $(patsubst tests/test_%.cc, tests/test_%, $(TEST_SRC))
|
||||
|
||||
-# -ltcmalloc_and_profiler
|
||||
-LDFLAGS = -Wl,-rpath,$(DEPS_PATH)/lib $(PS_LDFLAGS_SO) -pthread
|
||||
-tests/% : tests/%.cc build/libps.a
|
||||
- $(CXX) $(CFLAGS) $(LIBS) -MM -MT tests/$* $< >tests/$*.d
|
||||
- $(CXX) $(CFLAGS) $(LIBS) -o $@ $(filter %.cc %.a, $^) $(LDFLAGS)
|
||||
-
|
||||
--include tests/*.d
|
||||
+## -ltcmalloc_and_profiler
|
||||
+#LDFLAGS = -Wl,-rpath,$(DEPS_PATH)/lib $(PS_LDFLAGS_SO) -pthread
|
||||
+#tests/% : tests/%.cc build/libps.a
|
||||
+# $(CXX) $(CFLAGS) $(LIBS) -MM -MT tests/$* $< >tests/$*.d
|
||||
+# $(CXX) $(CFLAGS) $(LIBS) -o $@ $(filter %.cc %.a, $^) $(LDFLAGS)
|
||||
+#
|
||||
+#-include tests/*.d
|
Loading…
Reference in New Issue