forked from mindspore-Ecosystem/mindspore
Add server code part2
This commit is contained in:
parent
cb6e055736
commit
12f95b51f4
|
@ -35,7 +35,7 @@ function(ms_build_flatbuffers source_schema_files
|
|||
set(total_schema_dirs -I ${schema_dir} ${total_schema_dirs})
|
||||
endforeach()
|
||||
|
||||
foreach(schema ${source_schema_files})
|
||||
foreach(schema IN LISTS ${source_schema_files})
|
||||
get_filename_component(filename ${schema} NAME_WE)
|
||||
if(NOT ${generated_output_dir} STREQUAL "")
|
||||
set(generated_file ${generated_output_dir}/${filename}_generated.h)
|
||||
|
|
|
@ -212,7 +212,7 @@ if(ENABLE_GPU)
|
|||
)
|
||||
endif()
|
||||
|
||||
if(ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))
|
||||
if(ENABLE_CPU AND NOT WIN32)
|
||||
install(
|
||||
TARGETS ps_cache
|
||||
DESTINATION ${INSTALL_LIB_DIR}
|
||||
|
|
|
@ -373,7 +373,7 @@ elseif(CMAKE_SYSTEM_NAME MATCHES "Darwin")
|
|||
target_link_libraries(mindspore mindspore_gvar)
|
||||
target_link_libraries(_c_expression PRIVATE -Wl,-force_load mindspore mindspore_core -Wl,-noall_load)
|
||||
else()
|
||||
if(ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))
|
||||
if(ENABLE_CPU AND NOT WIN32)
|
||||
target_link_libraries(mindspore proto_input mindspore::protobuf
|
||||
mindspore::event mindspore::event_pthreads mindspore::event_openssl mindspore::json)
|
||||
target_link_libraries(mindspore -Wl,--no-as-needed mindspore::event_core ps_cache)
|
||||
|
|
|
@ -75,7 +75,7 @@ if(ENABLE_CPU)
|
|||
endif()
|
||||
endif()
|
||||
|
||||
if(NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)))
|
||||
if(NOT ENABLE_CPU OR WIN32)
|
||||
list(REMOVE_ITEM CPU_SRC_LIST "cpu/ps/apply_momentum_ps_kernel.cc")
|
||||
list(REMOVE_ITEM CPU_SRC_LIST "cpu/ps/embedding_look_up_proxy_kernel.cc")
|
||||
list(REMOVE_ITEM CPU_SRC_LIST "cpu/ps/embedding_look_up_ps_kernel.cc")
|
||||
|
|
|
@ -421,7 +421,7 @@ void AscendSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_gra
|
|||
size = abstract::ShapeSize(shape_tmp) * abstract::TypeIdSize(tensor->data_type());
|
||||
}
|
||||
if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0) && TensorNeedSync(input_node, tensor)) {
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
#if (ENABLE_CPU && !_WIN32)
|
||||
const std::string ¶m_name = input_node->fullname_with_scope();
|
||||
if (ps::ps_cache_instance.IsHashTable(param_name)) {
|
||||
continue;
|
||||
|
|
|
@ -33,7 +33,7 @@
|
|||
#include "debug/anf_ir_dump.h"
|
||||
#include "debug/dump_proto.h"
|
||||
#include "debug/data_dump/dump_json_parser.h"
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
#if (ENABLE_CPU && !_WIN32)
|
||||
#include "ps/util.h"
|
||||
#include "ps/ps_context.h"
|
||||
#endif
|
||||
|
@ -74,7 +74,7 @@ void CPUSession::Reorder(std::vector<CNodePtr> *node_list) { AnfAlgo::ReorderPos
|
|||
void CPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>();
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
#if (ENABLE_CPU && !_WIN32)
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode && ps::PSContext::instance()->is_ps_mode()) {
|
||||
|
@ -174,7 +174,7 @@ void CPUSession::PreExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_grap
|
|||
MS_LOG(INFO) << "Bind input output address";
|
||||
runtime_.BindInputOutput(kernel_graph.get(), inputs, outputs);
|
||||
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
#if (ENABLE_CPU && !_WIN32)
|
||||
InitPSParamAndOptim(kernel_graph, inputs);
|
||||
#endif
|
||||
}
|
||||
|
|
|
@ -21,7 +21,7 @@
|
|||
#include "utils/comm_manager.h"
|
||||
#include "utils/scoped_long_running.h"
|
||||
#include "pybind_api/ir/tensor_py.h"
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
#if (ENABLE_CPU && !_WIN32)
|
||||
#include "ps/ps_cache/ps_cache_manager.h"
|
||||
#endif
|
||||
|
||||
|
|
|
@ -43,7 +43,7 @@
|
|||
#include "debug/common.h"
|
||||
#include "utils/trace_base.h"
|
||||
#include "frontend/parallel/context.h"
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
#if (ENABLE_CPU && !_WIN32)
|
||||
#include "ps/ps_cache/ps_cache_manager.h"
|
||||
#include "ps/constants.h"
|
||||
#include "ps/util.h"
|
||||
|
@ -2357,7 +2357,7 @@ void SessionBasic::DumpGraph(const std::shared_ptr<KernelGraph> &kernel_graph) {
|
|||
#endif
|
||||
}
|
||||
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
#if (ENABLE_CPU && !_WIN32)
|
||||
void SessionBasic::InitPsWorker(const KernelGraphPtr &kernel_graph) {
|
||||
if (!ps::PSContext::instance()->is_worker()) {
|
||||
return;
|
||||
|
|
|
@ -244,7 +244,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
|
|||
std::vector<uint32_t> GetAllReduceSplitIndex();
|
||||
virtual std::string GetCommWorldGroup() { return std::string(); }
|
||||
void DumpGraph(const std::shared_ptr<KernelGraph> &kernel_graph);
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
#if (ENABLE_CPU && !_WIN32)
|
||||
void CheckPSModeConsistence(const KernelGraphPtr &kernel_graph) const;
|
||||
void GetBatchElements(const AnfNodePtr &kernel_node) const;
|
||||
void InitPsWorker(const KernelGraphPtr &kernel_graph);
|
||||
|
@ -263,7 +263,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
|
|||
#if !defined(_WIN32) && !defined(_WIN64)
|
||||
std::shared_ptr<Debugger> debugger_;
|
||||
#endif
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
#if (ENABLE_CPU && !_WIN32)
|
||||
bool initialized_ps_cache_{false};
|
||||
#endif
|
||||
};
|
||||
|
|
|
@ -25,7 +25,7 @@
|
|||
#include "frontend/parallel/device_matrix.h"
|
||||
#include "frontend/parallel/graph_util/generate_graph.h"
|
||||
#include "frontend/parallel/context.h"
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
#if (ENABLE_CPU && !_WIN32)
|
||||
#include "ps/ps_cache/ps_cache_manager.h"
|
||||
#include "utils/ms_context.h"
|
||||
#endif
|
||||
|
@ -160,7 +160,7 @@ Status GatherPInfo::GetAttrs() {
|
|||
if (std::find(inputs_shape_[1].begin(), inputs_shape_[1].end(), -1) != inputs_shape_[1].end()) {
|
||||
dynamic_shape_indices_ = true;
|
||||
}
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
#if (ENABLE_CPU && !_WIN32)
|
||||
MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
|
||||
std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode();
|
||||
MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
|
||||
|
@ -617,7 +617,7 @@ Status GatherPInfo::InferBias() {
|
|||
rank = rank % (params_strategy[0] * params_strategy[1]);
|
||||
}
|
||||
}
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
#if (ENABLE_CPU && !_WIN32)
|
||||
if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
|
||||
bias_ = static_cast<int64_t>(ps::PsCacheManager::GetInstance().cache_indices_lower_bound());
|
||||
return SUCCESS;
|
||||
|
|
|
@ -28,7 +28,7 @@
|
|||
#include "frontend/parallel/strategy.h"
|
||||
#include "frontend/parallel/context.h"
|
||||
#include "frontend/parallel/tensor_layout/tensor_redistribution.h"
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
#if (ENABLE_CPU && !_WIN32)
|
||||
#include "ps/ps_cache/ps_cache_manager.h"
|
||||
#endif
|
||||
|
||||
|
@ -192,7 +192,7 @@ Status UniqueInfo::GenerateStrategies(int64_t stage_id) {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
#if (ENABLE_CPU && !_WIN32)
|
||||
Status UniqueInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
|
||||
GenerateGraph gen_g = GenerateGraph();
|
||||
if (gen_g.Init(cnode) != SUCCESS) {
|
||||
|
@ -230,7 +230,7 @@ Status UniqueInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
|
|||
#endif
|
||||
|
||||
ReplaceGraphPtr UniqueInfo::replace_graph(const CNodePtr &cnode) {
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
#if (ENABLE_CPU && !_WIN32)
|
||||
if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
|
||||
auto inputs = cnode->inputs();
|
||||
if (inputs.empty()) {
|
||||
|
|
|
@ -51,7 +51,7 @@ class UniqueInfo : public OperatorInfo {
|
|||
Status InferMirrorOps() override;
|
||||
Status InferForwardCommunication() override { return SUCCESS; }
|
||||
Status InferAsLossDivisor() override { return SUCCESS; }
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
#if (ENABLE_CPU && !_WIN32)
|
||||
Status ComputeReplaceGraph(const CNodePtr &cnode);
|
||||
#endif
|
||||
|
||||
|
|
|
@ -47,14 +47,14 @@
|
|||
#include "ir/anf.h"
|
||||
#include "ir/param_info.h"
|
||||
#include "ir/tensor.h"
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
#if (ENABLE_CPU && !_WIN32)
|
||||
#include "ps/util.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
namespace parallel {
|
||||
bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) {
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
#if (ENABLE_CPU && !_WIN32)
|
||||
if (ps::Util::IsRoleOfPServer() || ps::Util::IsRoleOfScheduler()) {
|
||||
return false;
|
||||
}
|
||||
|
|
|
@ -46,7 +46,7 @@
|
|||
#include "utils/ms_context.h"
|
||||
#include "utils/symbolic.h"
|
||||
#include "mindspore/core/utils/parallel_node_check.h"
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
#if (ENABLE_CPU && !_WIN32)
|
||||
#include "ps/util.h"
|
||||
#include "ps/ps_context.h"
|
||||
#endif
|
||||
|
@ -3553,7 +3553,7 @@ static void HandleFullySplitParameters(const FuncGraphPtr &root) {
|
|||
}
|
||||
|
||||
bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) {
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
#if (ENABLE_CPU && !_WIN32)
|
||||
if (ps::PSContext::instance()->is_server() || ps::PSContext::instance()->is_scheduler()) {
|
||||
return false;
|
||||
}
|
||||
|
|
|
@ -295,7 +295,7 @@ if(${CMAKE_SYSTEM_NAME} MATCHES "Windows")
|
|||
target_link_libraries(_c_dataengine PRIVATE _c_mindrecord ${MINDRECORD_LINK_OBJECT} mindspore::sqlite)
|
||||
else()
|
||||
target_link_libraries(_c_dataengine PRIVATE _c_mindrecord)
|
||||
if(ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))
|
||||
if(ENABLE_CPU AND NOT WIN32)
|
||||
if(${ENABLE_IBVERBS} STREQUAL "ON")
|
||||
target_link_libraries(_c_dataengine PRIVATE ibverbs rdmacm)
|
||||
endif()
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
add_subdirectory(perf EXCLUDE_FROM_ALL)
|
||||
include_directories("${CMAKE_BINARY_DIR}/minddata/dataset/engine/cache")
|
||||
set(MD_FLATBUFFER_OU "${CMAKE_BINARY_DIR}/minddata/dataset/engine/cache")
|
||||
ms_build_flatbuffers("de_tensor.fbs" ${CMAKE_CURRENT_SOURCE_DIR} generated_engine_files ${MD_FLATBUFFER_OU})
|
||||
set(FBS_FILES de_tensor.fbs)
|
||||
ms_build_flatbuffers(FBS_FILES ${CMAKE_CURRENT_SOURCE_DIR} generated_engine_files ${MD_FLATBUFFER_OU})
|
||||
|
||||
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
|
||||
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
|
||||
|
|
|
@ -43,7 +43,7 @@
|
|||
#include "vm/transform.h"
|
||||
#include "parse/python_adapter.h"
|
||||
#include "frontend/optimizer/py_pass_manager.h"
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
#if (ENABLE_CPU && !_WIN32)
|
||||
#include "ps/parameter_server.h"
|
||||
#include "ps/scheduler.h"
|
||||
#include "ps/worker.h"
|
||||
|
@ -606,7 +606,7 @@ bool ExecuteAction(const ResourcePtr &res) {
|
|||
return true;
|
||||
}
|
||||
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
#if (ENABLE_CPU && !_WIN32)
|
||||
bool StartPSWorkerAction(const ResourcePtr &res) {
|
||||
ps::Worker::GetInstance().Run();
|
||||
return true;
|
||||
|
@ -782,7 +782,7 @@ std::vector<ActionItem> VmPipeline() {
|
|||
actions.emplace_back(std::make_pair("auto_monad_reorder", OrderEnforceAction));
|
||||
|
||||
actions.emplace_back(std::make_pair("validate", ValidateAction));
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
#if (ENABLE_CPU && !_WIN32)
|
||||
if (ps::PSContext::instance()->is_worker()) {
|
||||
actions.emplace_back(std::make_pair("worker", StartPSWorkerAction));
|
||||
}
|
||||
|
@ -796,7 +796,7 @@ std::vector<ActionItem> VmPipeline() {
|
|||
return actions;
|
||||
}
|
||||
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
#if (ENABLE_CPU && !_WIN32)
|
||||
std::vector<ActionItem> PServerPipeline() {
|
||||
auto actions = CommonPipeline();
|
||||
actions.emplace_back(std::make_pair("optimize", VmOptimizeAction));
|
||||
|
|
|
@ -34,7 +34,7 @@
|
|||
#else
|
||||
#include "runtime/device/gpu/distribution/collective_fake_init.h"
|
||||
#endif
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
#if (ENABLE_CPU && !_WIN32)
|
||||
#include "ps/util.h"
|
||||
#endif
|
||||
#include "ps/ps_context.h"
|
||||
|
|
|
@ -42,7 +42,7 @@
|
|||
#include "pipeline/jit/pipeline_split.h"
|
||||
#include "pipeline/jit/static_analysis/auto_monad.h"
|
||||
#include "frontend/optimizer/irpass/gradient_eliminate.h"
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
#if (ENABLE_CPU && !_WIN32)
|
||||
#include "ps/util.h"
|
||||
#include "ps/ps_context.h"
|
||||
#endif
|
||||
|
@ -407,7 +407,7 @@ bool AddRecomputationPass(const ResourcePtr &res) {
|
|||
}
|
||||
|
||||
bool AddCacheEmbeddingPass(const ResourcePtr &res) {
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
#if (ENABLE_CPU && !_WIN32)
|
||||
if (ps::PSContext::instance()->is_ps_mode()) {
|
||||
return true;
|
||||
}
|
||||
|
|
|
@ -49,7 +49,7 @@
|
|||
#include "utils/shape_utils.h"
|
||||
#include "utils/info.h"
|
||||
#include "load_mindir/load_model.h"
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
#if (ENABLE_CPU && !_WIN32)
|
||||
#include "ps/constants.h"
|
||||
#include "ps/util.h"
|
||||
#include "ps/worker.h"
|
||||
|
@ -528,7 +528,7 @@ std::vector<ActionItem> GetPipeline(const ResourcePtr &resource, const std::stri
|
|||
|
||||
std::string backend = MsContext::GetInstance()->backend_policy();
|
||||
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
#if (ENABLE_CPU && !_WIN32)
|
||||
if (ps::PSContext::instance()->is_server()) {
|
||||
resource->results()[kBackend] = compile::CreateBackend();
|
||||
return PServerPipeline();
|
||||
|
@ -961,7 +961,7 @@ bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t ba
|
|||
bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batch_size,
|
||||
const std::vector<TypePtr> &types, const std::vector<std::vector<int64_t>> &shapes,
|
||||
const std::vector<int64_t> &input_indexes, bool need_run) {
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
#if (ENABLE_CPU && !_WIN32)
|
||||
if ((ps::PSContext::instance()->is_ps_mode()) && (!ps::PSContext::instance()->is_worker())) {
|
||||
return true;
|
||||
}
|
||||
|
@ -1027,7 +1027,7 @@ bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batc
|
|||
}
|
||||
ConfigManager::GetInstance().set_iter_num(size);
|
||||
// PS cache does not support loop sink.
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
#if (ENABLE_CPU && !_WIN32)
|
||||
if (ps::PSContext::instance()->is_worker() && ps::PsDataPrefetch::GetInstance().cache_enable()) {
|
||||
ps::PsDataPrefetch::GetInstance().CreateDataChannel(queue_name, LongToSize(size));
|
||||
ConfigManager::GetInstance().set_iter_num(1);
|
||||
|
@ -1150,7 +1150,7 @@ void FinalizeBackend() {
|
|||
void ClearResAtexit() {
|
||||
MS_LOG(DEBUG) << "Pipeline clear all resource";
|
||||
pynative::ClearPyNativeSession();
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
#if (ENABLE_CPU && !_WIN32)
|
||||
if (ps::PSContext::instance()->is_ps_mode() && ps::PSContext::instance()->is_worker()) {
|
||||
if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
|
||||
ps::ps_cache_instance.Finalize();
|
||||
|
|
|
@ -1,6 +1,13 @@
|
|||
file(GLOB_RECURSE _PS_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
|
||||
|
||||
if(NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)))
|
||||
set(SERVER_FLATBUFFER_OUTPUT "${CMAKE_BINARY_DIR}/schema")
|
||||
set(FBS_FILES
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../schema/cipher.fbs
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../schema/fl_job.fbs
|
||||
)
|
||||
ms_build_flatbuffers(FBS_FILES ${CMAKE_CURRENT_SOURCE_DIR}../../schema generated_fbs_files ${SERVER_FLATBUFFER_OUTPUT})
|
||||
|
||||
if(NOT ENABLE_CPU OR WIN32)
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "optimizer_info_builder.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "optimizer_info.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "scheduler.cc")
|
||||
|
@ -12,11 +19,6 @@ if(NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)))
|
|||
list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/tcp_client.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/tcp_message_handler.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/tcp_server.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/communicator_base.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/http_communicator.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/tcp_communicator.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/http_msg_handler.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/tcp_msg_handler.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "core/node.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "core/node_manager.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_cache_manager.cc")
|
||||
|
@ -39,18 +41,32 @@ if(NOT ENABLE_GPU)
|
|||
list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/gpu/gpu_ps_cache.cc")
|
||||
endif()
|
||||
|
||||
if(WIN32 OR NOT ENABLE_CPU)
|
||||
if(NOT ENABLE_CPU OR WIN32)
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/communicator_base.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/http_communicator.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/tcp_communicator.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/http_msg_handler.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/tcp_msg_handler.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/apply_momentum_kernel.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/aggregation_kernel_factory.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/dense_grad_accum_kernel.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/optimizer_kernel_factory.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/round/round_kernel_factory.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/round/round_kernel.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/round/start_fl_job_kernel.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/params_info.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "server/consistent_hash_ring.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "server/iteration_timer.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "server/local_meta_storage.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "server/local_meta_store.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "server/memory_register.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "server/parameter_aggregator.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "server/executor.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "server/collective_ops_impl.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "server/distributed_count_service.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "server/distributed_metadata_store.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "server/iteration.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "server/model_store.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "server/round.cc")
|
||||
endif()
|
||||
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_data/ps_data_prefetch.cc")
|
||||
|
@ -59,3 +75,5 @@ add_subdirectory(ps_cache)
|
|||
|
||||
set_property(SOURCE ${_PS_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PS)
|
||||
add_library(_mindspore_ps_obj OBJECT ${_PS_SRC_FILES})
|
||||
add_dependencies(_mindspore_ps_obj generated_fbs_files)
|
||||
target_link_libraries(_mindspore_ps_obj mindspore::flatbuffers)
|
||||
|
|
|
@ -34,13 +34,26 @@
|
|||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace core {
|
||||
enum class TcpUserCommand { kPush, kPull, kCount, kReachThreshold, kResetCount, kGetValue, kPutValue, kCounterEvent };
|
||||
enum class TcpUserCommand {
|
||||
kPush,
|
||||
kPull,
|
||||
kCount,
|
||||
kReachThreshold,
|
||||
kResetCount,
|
||||
kGetMetadata,
|
||||
kUpdateMetadata,
|
||||
kCounterEvent
|
||||
};
|
||||
|
||||
const std::unordered_map<TcpUserCommand, std::string> kUserCommandToMsgType = {
|
||||
{TcpUserCommand::kPush, "push"}, {TcpUserCommand::kPull, "pull"},
|
||||
{TcpUserCommand::kCount, "count"}, {TcpUserCommand::kReachThreshold, "reachThreshold"},
|
||||
{TcpUserCommand::kResetCount, "resetCnt"}, {TcpUserCommand::kGetValue, "getValue"},
|
||||
{TcpUserCommand::kPutValue, "putValue"}, {TcpUserCommand::kCounterEvent, "counterEvent"},
|
||||
{TcpUserCommand::kPush, "push"},
|
||||
{TcpUserCommand::kPull, "pull"},
|
||||
{TcpUserCommand::kCount, "count"},
|
||||
{TcpUserCommand::kReachThreshold, "countReachThreshold"},
|
||||
{TcpUserCommand::kResetCount, "resetCnt"},
|
||||
{TcpUserCommand::kGetMetadata, "getMetadata"},
|
||||
{TcpUserCommand::kUpdateMetadata, "updateMetadata"},
|
||||
{TcpUserCommand::kCounterEvent, "counterEvent"},
|
||||
};
|
||||
|
||||
class TcpCommunicator : public CommunicatorBase {
|
||||
|
|
|
@ -0,0 +1,155 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
syntax = "proto3";
|
||||
package mindspore.ps;
|
||||
|
||||
message CollectiveData {
|
||||
bytes data = 1;
|
||||
}
|
||||
|
||||
message CountRequest {
|
||||
string name = 1;
|
||||
string id = 2;
|
||||
}
|
||||
|
||||
message CountResponse {
|
||||
bool result = 1;
|
||||
string reason = 2;
|
||||
}
|
||||
|
||||
message CountReachThresholdRequest {
|
||||
string name = 1;
|
||||
}
|
||||
|
||||
message CountReachThresholdResponse {
|
||||
bool is_enough = 1;
|
||||
}
|
||||
|
||||
message ResetCounterRequest {
|
||||
string name = 1;
|
||||
}
|
||||
|
||||
message UpdateMetadataRequest {
|
||||
string name = 1;
|
||||
bytes value = 2;
|
||||
}
|
||||
|
||||
message GetMetadataRequest {
|
||||
string name = 1;
|
||||
}
|
||||
|
||||
message GetMetadataResponse {
|
||||
bytes value = 1;
|
||||
}
|
||||
|
||||
enum CounterEventType {
|
||||
FIRST_CNT = 0;
|
||||
LAST_CNT = 1;
|
||||
}
|
||||
|
||||
message CounterEvent {
|
||||
CounterEventType type = 1;
|
||||
string name = 2;
|
||||
bytes data = 3;
|
||||
}
|
||||
|
||||
message FLId {
|
||||
string fl_id = 1;
|
||||
}
|
||||
|
||||
message UpdateModelClientList {
|
||||
repeated string fl_id = 1;
|
||||
}
|
||||
|
||||
message DeviceMeta {
|
||||
string fl_name = 1;
|
||||
string fl_id = 2;
|
||||
uint64 data_size = 3;
|
||||
}
|
||||
|
||||
message FLIdToDeviceMeta {
|
||||
map<string, DeviceMeta> fl_id_to_meta = 1;
|
||||
}
|
||||
|
||||
message UpdateModelThreshold {
|
||||
uint64 threshold = 1;
|
||||
}
|
||||
|
||||
message ClientShares {
|
||||
map<string, SharesPb> client_secret_shares = 1;
|
||||
}
|
||||
|
||||
message PairClientShares {
|
||||
string fl_id = 1;
|
||||
SharesPb client_shares = 2;
|
||||
}
|
||||
|
||||
message ClientKeys {
|
||||
map<string, KeysPb> client_keys = 1;
|
||||
}
|
||||
|
||||
message ClientNoises {
|
||||
OneClientNoises one_client_noises = 1;
|
||||
}
|
||||
|
||||
message PairClientKeys {
|
||||
string fl_id = 1;
|
||||
KeysPb client_keys = 2;
|
||||
}
|
||||
|
||||
message OneClientNoises {
|
||||
repeated float noise = 1;
|
||||
}
|
||||
|
||||
message ClientShareStr {
|
||||
string fl_id = 1;
|
||||
bytes share = 2; // todo: verify the correctness
|
||||
int32 index = 3;
|
||||
}
|
||||
|
||||
message SharesPb {
|
||||
repeated ClientShareStr clientsharestrs = 1;
|
||||
}
|
||||
|
||||
message KeysPb {
|
||||
repeated bytes key = 1;
|
||||
}
|
||||
|
||||
message PBMetadata {
|
||||
oneof value {
|
||||
DeviceMeta device_meta = 1;
|
||||
FLIdToDeviceMeta device_metas = 2;
|
||||
|
||||
FLId fl_id = 3;
|
||||
UpdateModelClientList client_list = 4;
|
||||
|
||||
UpdateModelThreshold update_model_threshold = 5;
|
||||
|
||||
PairClientShares pair_client_shares = 6;
|
||||
ClientShares client_shares = 7;
|
||||
|
||||
PairClientKeys pair_client_keys = 8;
|
||||
ClientKeys client_keys = 9;
|
||||
|
||||
OneClientNoises one_client_noises = 10;
|
||||
ClientNoises client_noises = 11;
|
||||
}
|
||||
}
|
||||
|
||||
message PBMetadataWithName {
|
||||
string name = 1;
|
||||
PBMetadata metadata = 2;
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
if(ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))
|
||||
if(ENABLE_CPU AND NOT WIN32)
|
||||
file(GLOB_RECURSE _PS_CACHE_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "ps_data/*.cc")
|
||||
set_property(SOURCE ${_PS_CACHE_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PS)
|
||||
add_library(ps_cache SHARED ${_PS_CACHE_SRC_FILES})
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
#include "utils/log_adapter.h"
|
||||
#include "utils/ms_utils.h"
|
||||
#include "backend/kernel_compiler/kernel.h"
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
#if (ENABLE_CPU && !_WIN32)
|
||||
#include "ps/ps_cache/ps_cache_manager.h"
|
||||
#include "ps/ps_cache/ps_data/ps_data_prefetch.h"
|
||||
#endif
|
||||
|
@ -68,7 +68,7 @@ void PSContext::Reset() {
|
|||
is_worker_ = false;
|
||||
is_pserver_ = false;
|
||||
is_sched_ = false;
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
#if (ENABLE_CPU && !_WIN32)
|
||||
if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
|
||||
ps_cache_instance.Finalize();
|
||||
set_cache_enable(false);
|
||||
|
@ -108,46 +108,62 @@ int PSContext::ps_rank_id() const { return rank_id_; }
|
|||
|
||||
void PSContext::InsertHashTableSize(const std::string ¶m_name, size_t cache_vocab_size, size_t embedding_size,
|
||||
size_t vocab_size) const {
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
#if (ENABLE_CPU && !_WIN32)
|
||||
ps_cache_instance.InsertHashTableSize(param_name, cache_vocab_size, embedding_size, vocab_size);
|
||||
#endif
|
||||
}
|
||||
|
||||
void PSContext::ReInsertHashTableSize(const std::string &new_param_name, const std::string &cur_param_name,
|
||||
size_t cache_vocab_size, size_t embedding_size) const {
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
#if (ENABLE_CPU && !_WIN32)
|
||||
ps_cache_instance.ReInsertHashTableSize(new_param_name, cur_param_name, cache_vocab_size, embedding_size);
|
||||
#endif
|
||||
}
|
||||
|
||||
void PSContext::InsertWeightInitInfo(const std::string ¶m_name, size_t global_seed, size_t op_seed) const {
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
#if (ENABLE_CPU && !_WIN32)
|
||||
ps_cache_instance.InsertWeightInitInfo(param_name, global_seed, op_seed);
|
||||
#endif
|
||||
}
|
||||
|
||||
void PSContext::InsertAccumuInitInfo(const std::string ¶m_name, float init_val) const {
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
#if (ENABLE_CPU && !_WIN32)
|
||||
ps_cache_instance.InsertAccumuInitInfo(param_name, init_val);
|
||||
#endif
|
||||
}
|
||||
|
||||
void PSContext::CloneHashTable(const std::string &dest_param_name, const std::string &src_param_name) const {
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
#if (ENABLE_CPU && !_WIN32)
|
||||
ps_cache_instance.CloneHashTable(dest_param_name, src_param_name);
|
||||
#endif
|
||||
}
|
||||
|
||||
void PSContext::set_cache_enable(bool cache_enable) const {
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
#if (ENABLE_CPU && !_WIN32)
|
||||
PsDataPrefetch::GetInstance().set_cache_enable(cache_enable);
|
||||
#endif
|
||||
}
|
||||
|
||||
void PSContext::set_rank_id(int rank_id) const {
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
#if (ENABLE_CPU && !_WIN32)
|
||||
ps_cache_instance.set_rank_id(rank_id);
|
||||
#endif
|
||||
}
|
||||
|
||||
void PSContext::set_fl_name(const std::string &fl_name) { fl_name_ = fl_name; }
|
||||
|
||||
const std::string &PSContext::fl_name() const { return fl_name_; }
|
||||
|
||||
void PSContext::set_fl_iteration_num(uint64_t fl_iteration_num) { fl_iteration_num_ = fl_iteration_num; }
|
||||
|
||||
uint64_t PSContext::fl_iteration_num() const { return fl_iteration_num_; }
|
||||
|
||||
void PSContext::set_client_epoch_num(uint64_t client_epoch_num) { client_epoch_num_ = client_epoch_num; }
|
||||
|
||||
uint64_t PSContext::client_epoch_num() const { return client_epoch_num_; }
|
||||
|
||||
void PSContext::set_client_batch_size(uint64_t client_batch_size) { client_batch_size_ = client_batch_size; }
|
||||
|
||||
uint64_t PSContext::client_batch_size() const { return client_batch_size_; }
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -60,6 +60,19 @@ class PSContext {
|
|||
void set_cache_enable(bool cache_enable) const;
|
||||
void set_rank_id(int rank_id) const;
|
||||
|
||||
// Setter and getter for federated learning.
|
||||
void set_fl_name(const std::string &fl_name);
|
||||
const std::string &fl_name() const;
|
||||
|
||||
void set_fl_iteration_num(uint64_t fl_iteration_num);
|
||||
uint64_t fl_iteration_num() const;
|
||||
|
||||
void set_client_epoch_num(uint64_t client_epoch_num);
|
||||
uint64_t client_epoch_num() const;
|
||||
|
||||
void set_client_batch_size(uint64_t client_batch_size);
|
||||
uint64_t client_batch_size() const;
|
||||
|
||||
private:
|
||||
PSContext()
|
||||
: ps_enabled_(false),
|
||||
|
@ -80,6 +93,12 @@ class PSContext {
|
|||
uint32_t server_num_;
|
||||
std::string scheduler_host_;
|
||||
uint16_t scheduler_port_;
|
||||
|
||||
// Members for federated learning.
|
||||
std::string fl_name_;
|
||||
uint64_t fl_iteration_num_;
|
||||
uint64_t client_epoch_num_;
|
||||
uint64_t client_batch_size_;
|
||||
};
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -0,0 +1,223 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ps/server/collective_ops_impl.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace server {
|
||||
void CollectiveOpsImpl::Initialize(const std::shared_ptr<core::ServerNode> &server_node) {
|
||||
MS_EXCEPTION_IF_NULL(server_node);
|
||||
server_node_ = server_node;
|
||||
local_rank_ = server_node_->rank_id();
|
||||
server_num_ = PSContext::instance()->initial_server_num();
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool CollectiveOpsImpl::RingAllReduce(const void *sendbuff, void *recvbuff, size_t count) {
|
||||
int ret = memcpy_s(recvbuff, count * sizeof(T), sendbuff, count * sizeof(T));
|
||||
if (ret != 0) {
|
||||
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
|
||||
return false;
|
||||
}
|
||||
|
||||
uint32_t rank_size = server_num_;
|
||||
uint32_t local_rank_ = server_node_->rank_id();
|
||||
size_t chunk_size = count / rank_size;
|
||||
size_t remainder_size = count % rank_size;
|
||||
std::vector<size_t> chunk_sizes(rank_size, chunk_size);
|
||||
// The rest of the data should be assigned to each chunk.
|
||||
for (size_t i = 0; i < remainder_size; i++) {
|
||||
chunk_sizes[i]++;
|
||||
}
|
||||
// Store offsets to get every data chunk's address.
|
||||
std::vector<size_t> chunk_offset;
|
||||
for (size_t i = 0; i < rank_size; i++) {
|
||||
size_t ofs =
|
||||
std::accumulate(chunk_sizes.begin(), chunk_sizes.begin() + i, static_cast<size_t>(0), std::plus<size_t>());
|
||||
chunk_offset.push_back(ofs);
|
||||
}
|
||||
|
||||
T *output_buff = reinterpret_cast<T *>(recvbuff);
|
||||
uint32_t send_to_rank = (local_rank_ + 1) % rank_size;
|
||||
uint32_t recv_from_rank = (local_rank_ - 1 + rank_size) % rank_size;
|
||||
MS_LOG(DEBUG) << "AllReduce count:" << count << ", rank_size:" << rank_size << ", local_rank_:" << local_rank_
|
||||
<< ", chunk_size:" << chunk_size << ", remainder_size:" << remainder_size
|
||||
<< ", chunk_sizes:" << chunk_sizes << ", send_to_rank:" << send_to_rank
|
||||
<< ", recv_from_rank:" << recv_from_rank;
|
||||
|
||||
// Ring ReduceScatter.
|
||||
MS_LOG(DEBUG) << "Start Ring ReduceScatter.";
|
||||
std::unique_ptr<T[]> tmp_recv_chunk = std::make_unique<T[]>(chunk_sizes[0]);
|
||||
for (size_t i = 0; i < rank_size - 1; i++) {
|
||||
// Step 1: Async send data to next rank.
|
||||
size_t send_chunk_index = (local_rank_ - i + rank_size) % rank_size;
|
||||
T *send_chunk = output_buff + chunk_offset[send_chunk_index];
|
||||
auto send_req_id = server_node_->CollectiveSendAsync(core::NodeRole::SERVER, send_to_rank, send_chunk,
|
||||
chunk_sizes[send_chunk_index] * sizeof(T));
|
||||
// Step 2: Async receive data to next rank and wait until it's done.
|
||||
size_t recv_chunk_index = (local_rank_ - i - 1 + rank_size) % rank_size;
|
||||
T *recv_chunk = output_buff + chunk_offset[recv_chunk_index];
|
||||
MS_LOG(DEBUG) << "Ring ReduceScatter send_to_rank:" << send_to_rank << ", recv_from_rank:" << recv_from_rank
|
||||
<< ", send count:" << chunk_sizes[send_chunk_index]
|
||||
<< ", recv count:" << chunk_sizes[recv_chunk_index] << ", iteration:" << i;
|
||||
|
||||
std::shared_ptr<std::vector<unsigned char>> recv_str;
|
||||
auto recv_req_id = server_node_->CollectiveReceiveAsync(core::NodeRole::SERVER, recv_from_rank, &recv_str);
|
||||
if (!server_node_->CollectiveWait(recv_req_id, 1)) {
|
||||
MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed.";
|
||||
return false;
|
||||
}
|
||||
memcpy_s(tmp_recv_chunk.get(), chunk_sizes[recv_chunk_index] * sizeof(T), recv_str->data(), recv_str->size());
|
||||
|
||||
// Step 3: Reduce the data so we can overlap the time cost of send.
|
||||
for (size_t j = 0; j < chunk_sizes[recv_chunk_index]; j++) {
|
||||
recv_chunk[j] += tmp_recv_chunk[j];
|
||||
}
|
||||
// Step 4: Wait until send is done.
|
||||
if (!server_node_->Wait(send_req_id, 1)) {
|
||||
MS_LOG(ERROR) << "CollectiveWait " << send_req_id << " failed.";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
MS_LOG(DEBUG) << "End Ring ReduceScatter.";
|
||||
|
||||
// Ring AllGather.
|
||||
MS_LOG(DEBUG) << "Start Ring AllGather.";
|
||||
for (size_t i = 0; i < rank_size - 1; i++) {
|
||||
size_t send_chunk_index = (local_rank_ - i + 1 + rank_size) % rank_size;
|
||||
T *send_chunk = output_buff + chunk_offset[send_chunk_index];
|
||||
auto send_req_id = server_node_->CollectiveSendAsync(core::NodeRole::SERVER, send_to_rank, send_chunk,
|
||||
chunk_sizes[send_chunk_index] * sizeof(T));
|
||||
size_t recv_chunk_index = (local_rank_ - i + rank_size) % rank_size;
|
||||
T *recv_chunk = output_buff + chunk_offset[recv_chunk_index];
|
||||
MS_LOG(DEBUG) << "Ring AllGather send_to_rank:" << send_to_rank << ", recv_from_rank:" << recv_from_rank
|
||||
<< ", send count:" << chunk_sizes[send_chunk_index]
|
||||
<< ", recv count:" << chunk_sizes[recv_chunk_index] << ", iteration:" << i;
|
||||
|
||||
std::shared_ptr<std::vector<unsigned char>> recv_str;
|
||||
auto recv_req_id = server_node_->CollectiveReceiveAsync(core::NodeRole::SERVER, recv_from_rank, &recv_str);
|
||||
|
||||
if (!server_node_->CollectiveWait(recv_req_id, 1)) {
|
||||
MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed.";
|
||||
return false;
|
||||
}
|
||||
memcpy_s(recv_chunk, chunk_sizes[recv_chunk_index] * sizeof(T), recv_str->data(), recv_str->size());
|
||||
if (!server_node_->Wait(send_req_id, 1)) {
|
||||
MS_LOG(ERROR) << "CollectiveWait " << send_req_id << " failed.";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
MS_LOG(DEBUG) << "End Ring AllGather.";
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool CollectiveOpsImpl::ReduceBroadcastAllReduce(const void *sendbuff, void *recvbuff, size_t count) {
|
||||
uint32_t rank_size = server_num_;
|
||||
uint32_t local_rank_ = server_node_->rank_id();
|
||||
MS_LOG(DEBUG) << "Reduce Broadcast AllReduce rank_size:" << rank_size << ", local_rank_:" << local_rank_
|
||||
<< ", count:" << count;
|
||||
int ret = memcpy_s(recvbuff, count * sizeof(T), sendbuff, count * sizeof(T));
|
||||
if (ret != 0) {
|
||||
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
|
||||
return false;
|
||||
}
|
||||
T *output_buff = reinterpret_cast<T *>(recvbuff);
|
||||
// Reduce data to rank 0 process.
|
||||
MS_LOG(DEBUG) << "Start Reduce to rank 0 process.";
|
||||
if (local_rank_ == 0) {
|
||||
std::unique_ptr<T[]> tmp_recv_buff = std::make_unique<T[]>(count);
|
||||
for (uint32_t i = 1; i < rank_size; i++) {
|
||||
std::shared_ptr<std::vector<unsigned char>> recv_str;
|
||||
MS_LOG(DEBUG) << "Reduce rank 0 receive from rank " << i;
|
||||
auto recv_req_id = server_node_->CollectiveReceiveAsync(core::NodeRole::SERVER, i, &recv_str);
|
||||
if (!server_node_->CollectiveWait(recv_req_id, 1)) {
|
||||
MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed.";
|
||||
return false;
|
||||
}
|
||||
memcpy_s(tmp_recv_buff.get(), count * sizeof(T), recv_str->data(), recv_str->size());
|
||||
for (size_t j = 0; j < count; j++) {
|
||||
output_buff[j] += tmp_recv_buff[j];
|
||||
}
|
||||
}
|
||||
} else {
|
||||
MS_LOG(DEBUG) << "Reduce send data to rank 0 process.";
|
||||
auto send_req_id = server_node_->CollectiveSendAsync(core::NodeRole::SERVER, 0, sendbuff, count * sizeof(T));
|
||||
if (!server_node_->Wait(send_req_id, 1)) {
|
||||
MS_LOG(ERROR) << "CollectiveWait " << send_req_id << " failed.";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
MS_LOG(DEBUG) << "End Reduce.";
|
||||
|
||||
// Broadcast data to not 0 rank process.
|
||||
MS_LOG(DEBUG) << "Start broadcast from rank 0 to other processes.";
|
||||
if (local_rank_ == 0) {
|
||||
for (uint32_t i = 1; i < rank_size; i++) {
|
||||
MS_LOG(DEBUG) << "Broadcast data to process " << i;
|
||||
auto send_req_id = server_node_->CollectiveSendAsync(core::NodeRole::SERVER, i, output_buff, count * sizeof(T));
|
||||
if (!server_node_->Wait(send_req_id, 1)) {
|
||||
MS_LOG(ERROR) << "CollectiveWait " << send_req_id << " failed.";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
MS_LOG(DEBUG) << "Broadcast receive from rank 0.";
|
||||
std::shared_ptr<std::vector<unsigned char>> recv_str;
|
||||
auto recv_req_id = server_node_->CollectiveReceiveAsync(core::NodeRole::SERVER, 0, &recv_str);
|
||||
if (!server_node_->CollectiveWait(recv_req_id, 1)) {
|
||||
MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed.";
|
||||
return false;
|
||||
}
|
||||
memcpy_s(output_buff, count * sizeof(T), recv_str->data(), recv_str->size());
|
||||
}
|
||||
MS_LOG(DEBUG) << "End broadcast.";
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool CollectiveOpsImpl::AllReduce(const void *sendbuff, void *recvbuff, size_t count) {
|
||||
// The collective communication API does not support calling Send and Recv concurrently with multiple threads;
|
||||
std::unique_lock<std::mutex> lock(mtx_);
|
||||
if (sendbuff == nullptr || recvbuff == nullptr) {
|
||||
MS_LOG(ERROR) << "AllReduce sendbuff or recvbuff is nullptr.";
|
||||
return false;
|
||||
}
|
||||
|
||||
uint32_t rank_size = server_num_;
|
||||
if (count >= rank_size) {
|
||||
return RingAllReduce<T>(sendbuff, recvbuff, count);
|
||||
} else {
|
||||
return ReduceBroadcastAllReduce<T>(sendbuff, recvbuff, count);
|
||||
}
|
||||
}
|
||||
|
||||
template bool CollectiveOpsImpl::RingAllReduce<float>(const void *sendbuff, void *recvbuff, size_t count);
|
||||
template bool CollectiveOpsImpl::RingAllReduce<size_t>(const void *sendbuff, void *recvbuff, size_t count);
|
||||
template bool CollectiveOpsImpl::RingAllReduce<int>(const void *sendbuff, void *recvbuff, size_t count);
|
||||
|
||||
template bool CollectiveOpsImpl::ReduceBroadcastAllReduce<float>(const void *sendbuff, void *recvbuff, size_t count);
|
||||
template bool CollectiveOpsImpl::ReduceBroadcastAllReduce<size_t>(const void *sendbuff, void *recvbuff, size_t count);
|
||||
template bool CollectiveOpsImpl::ReduceBroadcastAllReduce<int>(const void *sendbuff, void *recvbuff, size_t count);
|
||||
|
||||
template bool CollectiveOpsImpl::AllReduce<float>(const void *sendbuff, void *recvbuff, size_t count);
|
||||
template bool CollectiveOpsImpl::AllReduce<size_t>(const void *sendbuff, void *recvbuff, size_t count);
|
||||
template bool CollectiveOpsImpl::AllReduce<int>(const void *sendbuff, void *recvbuff, size_t count);
|
||||
} // namespace server
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,71 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PS_SERVER_COLLECTIVE_OPS_IMPL_H_
|
||||
#define MINDSPORE_CCSRC_PS_SERVER_COLLECTIVE_OPS_IMPL_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <functional>
|
||||
#include "proto/ps.pb.h"
|
||||
#include "ps/ps_context.h"
|
||||
#include "ps/core/server_node.h"
|
||||
#include "ps/server/common.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace server {
|
||||
// CollectiveOpsImpl is the collective communication API of the server.
|
||||
// For now, it implements two AllReduce algorithms: RingAllReduce and BroadcastAllReduce. Elastic AllReduce is also
|
||||
// supported for the elastic scaling feature of the server.
|
||||
class CollectiveOpsImpl {
|
||||
public:
|
||||
static CollectiveOpsImpl &GetInstance() {
|
||||
static CollectiveOpsImpl instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
void Initialize(const std::shared_ptr<core::ServerNode> &server_node);
|
||||
|
||||
template <typename T>
|
||||
bool AllReduce(const void *sendbuff, void *recvbuff, size_t count);
|
||||
|
||||
private:
|
||||
CollectiveOpsImpl() = default;
|
||||
~CollectiveOpsImpl() = default;
|
||||
CollectiveOpsImpl(const CollectiveOpsImpl &) = delete;
|
||||
CollectiveOpsImpl &operator=(const CollectiveOpsImpl &) = delete;
|
||||
|
||||
// Implementation of RingAllReduce.
|
||||
template <typename T>
|
||||
bool RingAllReduce(const void *sendbuff, void *recvbuff, size_t count);
|
||||
|
||||
// Implementation of BroadcastAllReduce.
|
||||
template <typename T>
|
||||
bool ReduceBroadcastAllReduce(const void *sendbuff, void *recvbuff, size_t count);
|
||||
|
||||
std::shared_ptr<core::ServerNode> server_node_;
|
||||
uint32_t local_rank_;
|
||||
uint32_t server_num_;
|
||||
|
||||
// The mutex to ensure that collective communication is threadsafe.
|
||||
std::mutex mtx_;
|
||||
};
|
||||
} // namespace server
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PS_SERVER_COLLECTIVE_OPS_IMPL_H_
|
|
@ -24,13 +24,17 @@
|
|||
#include <memory>
|
||||
#include <functional>
|
||||
#include "proto/ps.pb.h"
|
||||
#include "proto/fl.pb.h"
|
||||
#include "ir/anf.h"
|
||||
#include "utils/utils.h"
|
||||
#include "ir/dtype/type_id.h"
|
||||
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
|
||||
#include "schema/fl_job_generated.h"
|
||||
#include "schema/cipher_generated.h"
|
||||
#include "ps/ps_context.h"
|
||||
#include "ps/core/communicator/http_message_handler.h"
|
||||
#include "ps/core/communicator/tcp_server.h"
|
||||
#include "ps/core/communicator/message_handler.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
|
@ -40,13 +44,15 @@ enum ServerMode { PARAMETER_SERVER = 0, FL_SERVER };
|
|||
enum CommType { HTTP = 0, TCP };
|
||||
enum AggregationType { FedAvg = 0, FedAdam, FedAdagarg, FedMeta, qffl, DenseGradAccum, SparseGradAccum };
|
||||
|
||||
using kernel::Address;
|
||||
using kernel::AddressPtr;
|
||||
using kernel::CPUKernel;
|
||||
using mindspore::kernel::Address;
|
||||
using mindspore::kernel::AddressPtr;
|
||||
using mindspore::kernel::CPUKernel;
|
||||
using FBBuilder = flatbuffers::FlatBufferBuilder;
|
||||
using TimeOutCb = std::function<void(void)>;
|
||||
using StopTimerCb = std::function<void(void)>;
|
||||
using FinishIterCb = std::function<void(void)>;
|
||||
using FinalizeCb = std::function<void(void)>;
|
||||
using MessageCallback = std::function<void(const std::shared_ptr<core::MessageHandler> &)>;
|
||||
|
||||
// Information about whether server kernel will reuse kernel node memory from the front end.
|
||||
// Key refers to the server kernel's parameter name, like "weights", "grad", "learning_rate".
|
||||
|
|
|
@ -0,0 +1,298 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ps/server/distributed_count_service.h"
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace server {
|
||||
void DistributedCountService::Initialize(const std::shared_ptr<core::ServerNode> &server_node,
|
||||
uint32_t counting_server_rank) {
|
||||
server_node_ = server_node;
|
||||
MS_EXCEPTION_IF_NULL(server_node_);
|
||||
|
||||
communicator_ =
|
||||
std::dynamic_pointer_cast<core::TcpCommunicator>(server_node_->GetOrCreateTcpComm("", 0, 0, 0, nullptr));
|
||||
MS_EXCEPTION_IF_NULL(communicator_);
|
||||
|
||||
local_rank_ = server_node_->rank_id();
|
||||
server_num_ = PSContext::instance()->initial_server_num();
|
||||
counting_server_rank_ = counting_server_rank;
|
||||
RegisterCallback();
|
||||
return;
|
||||
}
|
||||
|
||||
void DistributedCountService::RegisterCounter(const std::string &name, size_t global_threshold_count,
|
||||
const CounterHandlers &counter_handlers) {
|
||||
if (!counter_handlers.first_count_handler || !counter_handlers.last_count_handler) {
|
||||
MS_LOG(EXCEPTION) << "First count handler or last count handler is not set.";
|
||||
return;
|
||||
}
|
||||
if (global_threshold_count_.count(name) != 0) {
|
||||
MS_LOG(ERROR) << "Counter for " << name << " is already set.";
|
||||
return;
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "Rank " << local_rank_ << " register counter for " << name << " count:" << global_threshold_count;
|
||||
// If the server is the leader server, it needs to set the counter handlers and do the real counting.
|
||||
if (local_rank_ == counting_server_rank_) {
|
||||
global_current_count_[name] = {};
|
||||
global_threshold_count_[name] = global_threshold_count;
|
||||
mutex_[name];
|
||||
}
|
||||
counter_handlers_[name] = counter_handlers;
|
||||
return;
|
||||
}
|
||||
|
||||
bool DistributedCountService::Count(const std::string &name, const std::string &id) {
|
||||
MS_LOG(INFO) << "Rank " << local_rank_ << " reports count for " << name << " of " << id;
|
||||
if (local_rank_ == counting_server_rank_) {
|
||||
if (global_threshold_count_.count(name) == 0) {
|
||||
MS_LOG(ERROR) << "Counter for " << name << " is not registered.";
|
||||
return false;
|
||||
}
|
||||
|
||||
std::unique_lock<std::mutex> lock(mutex_[name]);
|
||||
if (global_current_count_[name].size() >= global_threshold_count_[name]) {
|
||||
MS_LOG(ERROR) << "Count for " << name << " is already enough. Threshold count is "
|
||||
<< global_threshold_count_[name];
|
||||
return false;
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "Leader server increase count for " << name << " of " << id;
|
||||
global_current_count_[name].insert(id);
|
||||
TriggerCounterEvent(name);
|
||||
} else {
|
||||
// If this server is a follower server, it needs to send CountRequest to the leader server.
|
||||
CountRequest report_count_req;
|
||||
report_count_req.set_name(name);
|
||||
report_count_req.set_id(id);
|
||||
|
||||
std::shared_ptr<std::vector<unsigned char>> report_cnt_rsp_msg = nullptr;
|
||||
if (!communicator_->SendPbRequest(report_count_req, counting_server_rank_, core::TcpUserCommand::kCount,
|
||||
&report_cnt_rsp_msg)) {
|
||||
MS_LOG(ERROR) << "Sending reporting count message to leader server failed for " << name;
|
||||
return false;
|
||||
}
|
||||
|
||||
CountResponse count_rsp;
|
||||
count_rsp.ParseFromArray(report_cnt_rsp_msg->data(), report_cnt_rsp_msg->size());
|
||||
if (!count_rsp.result()) {
|
||||
MS_LOG(ERROR) << "Reporting count failed:" << count_rsp.reason();
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool DistributedCountService::CountReachThreshold(const std::string &name) {
|
||||
MS_LOG(INFO) << "Rank " << local_rank_ << " query whether count reaches threshold for " << name;
|
||||
if (local_rank_ == counting_server_rank_) {
|
||||
if (global_threshold_count_.count(name) == 0) {
|
||||
MS_LOG(ERROR) << "Counter for " << name << " is not set.";
|
||||
return false;
|
||||
}
|
||||
|
||||
std::unique_lock<std::mutex> lock(mutex_[name]);
|
||||
return global_current_count_[name].size() == global_threshold_count_[name];
|
||||
} else {
|
||||
CountReachThresholdRequest count_reach_threashold_req;
|
||||
count_reach_threashold_req.set_name(name);
|
||||
|
||||
std::shared_ptr<std::vector<unsigned char>> query_cnt_enough_rsp_msg = nullptr;
|
||||
if (!communicator_->SendPbRequest(count_reach_threashold_req, counting_server_rank_,
|
||||
core::TcpUserCommand::kReachThreshold, &query_cnt_enough_rsp_msg)) {
|
||||
MS_LOG(ERROR) << "Sending querying whether count reaches threshold message to leader server failed for " << name;
|
||||
return false;
|
||||
}
|
||||
|
||||
CountReachThresholdResponse count_reach_threashold_rsp;
|
||||
count_reach_threashold_rsp.ParseFromArray(query_cnt_enough_rsp_msg->data(), query_cnt_enough_rsp_msg->size());
|
||||
return count_reach_threashold_rsp.is_enough();
|
||||
}
|
||||
}
|
||||
|
||||
void DistributedCountService::ResetCounter(const std::string &name) {
|
||||
if (local_rank_ == counting_server_rank_) {
|
||||
MS_LOG(INFO) << "Leader server reset count for " << name;
|
||||
global_current_count_[name].clear();
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
void DistributedCountService::RegisterCallback() {
|
||||
if (local_rank_ == counting_server_rank_) {
|
||||
communicator_->RegisterMsgCallBack(
|
||||
"count", std::bind(&DistributedCountService::HandleCountRequest, this, std::placeholders::_1));
|
||||
communicator_->RegisterMsgCallBack(
|
||||
"countReachThreshold",
|
||||
std::bind(&DistributedCountService::HandleCountReachThresholdRequest, this, std::placeholders::_1));
|
||||
}
|
||||
|
||||
// The callback of first/last event must be set in both leader server and follower servers.
|
||||
communicator_->RegisterMsgCallBack(
|
||||
"counterEvent", std::bind(&DistributedCountService::HandleCounterEvent, this, std::placeholders::_1));
|
||||
}
|
||||
|
||||
void DistributedCountService::HandleCountRequest(const std::shared_ptr<core::MessageHandler> &message) {
|
||||
if (message == nullptr) {
|
||||
MS_LOG(ERROR) << "Message is nullptr.";
|
||||
return;
|
||||
}
|
||||
|
||||
CountRequest report_count_req;
|
||||
report_count_req.ParseFromArray(message->data(), message->len());
|
||||
const std::string &name = report_count_req.name();
|
||||
const std::string &id = report_count_req.id();
|
||||
|
||||
CountResponse count_rsp;
|
||||
std::unique_lock<std::mutex> lock(mutex_[name]);
|
||||
// If leader server has no counter for the name registered, return an error.
|
||||
if (global_threshold_count_.count(name) == 0) {
|
||||
std::string reason = "Counter for " + name + " is not registered.";
|
||||
count_rsp.set_result(false);
|
||||
count_rsp.set_reason(reason);
|
||||
MS_LOG(ERROR) << reason;
|
||||
communicator_->SendResponse(count_rsp.SerializeAsString().data(), count_rsp.SerializeAsString().size(), message);
|
||||
return;
|
||||
}
|
||||
|
||||
// If leader server already has enough count for the name, return an error.
|
||||
if (global_current_count_[name].size() >= global_threshold_count_[name]) {
|
||||
std::string reason =
|
||||
"Count for " + name + " is already enough. Threshold count is " + std::to_string(global_threshold_count_[name]);
|
||||
count_rsp.set_result(false);
|
||||
count_rsp.set_reason(reason);
|
||||
MS_LOG(ERROR) << reason;
|
||||
communicator_->SendResponse(count_rsp.SerializeAsString().data(), count_rsp.SerializeAsString().size(), message);
|
||||
return;
|
||||
}
|
||||
|
||||
// Insert the id for the counter, which means the count for the name is increased.
|
||||
MS_LOG(INFO) << "Leader server increase count for " << name << " of " << id;
|
||||
global_current_count_[name].insert(id);
|
||||
TriggerCounterEvent(name);
|
||||
count_rsp.set_result(true);
|
||||
count_rsp.set_reason("success");
|
||||
communicator_->SendResponse(count_rsp.SerializeAsString().data(), count_rsp.SerializeAsString().size(), message);
|
||||
return;
|
||||
}
|
||||
|
||||
void DistributedCountService::HandleCountReachThresholdRequest(const std::shared_ptr<core::MessageHandler> &message) {
|
||||
if (message == nullptr) {
|
||||
MS_LOG(ERROR) << "Message is nullptr.";
|
||||
return;
|
||||
}
|
||||
|
||||
CountReachThresholdRequest count_reach_threashold_req;
|
||||
count_reach_threashold_req.ParseFromArray(message->data(), message->len());
|
||||
const std::string &name = count_reach_threashold_req.name();
|
||||
|
||||
std::unique_lock<std::mutex> lock(mutex_[name]);
|
||||
if (global_threshold_count_.count(name) == 0) {
|
||||
MS_LOG(ERROR) << "Counter for " << name << " is not registered.";
|
||||
return;
|
||||
}
|
||||
|
||||
CountReachThresholdResponse count_reach_threashold_rsp;
|
||||
count_reach_threashold_rsp.set_is_enough(global_current_count_[name].size() == global_threshold_count_[name]);
|
||||
communicator_->SendResponse(count_reach_threashold_rsp.SerializeAsString().data(),
|
||||
count_reach_threashold_rsp.SerializeAsString().size(), message);
|
||||
return;
|
||||
}
|
||||
|
||||
void DistributedCountService::HandleCounterEvent(const std::shared_ptr<core::MessageHandler> &message) {
|
||||
if (message == nullptr) {
|
||||
MS_LOG(ERROR) << "Message is nullptr.";
|
||||
return;
|
||||
}
|
||||
|
||||
// Respond as soon as possible so the leader server won't wait for each follower servers to finish calling the
|
||||
// callbacks.
|
||||
std::string couter_event_rsp_msg = "success";
|
||||
communicator_->SendResponse(couter_event_rsp_msg.data(), couter_event_rsp_msg.size(), message);
|
||||
|
||||
CounterEvent counter_event;
|
||||
counter_event.ParseFromArray(message->data(), message->len());
|
||||
const auto &type = counter_event.type();
|
||||
const auto &name = counter_event.name();
|
||||
|
||||
MS_LOG(INFO) << "Rank " << local_rank_ << " do counter event " << type << " for " << name;
|
||||
if (type == CounterEventType::FIRST_CNT) {
|
||||
counter_handlers_[name].first_count_handler(message);
|
||||
} else if (type == CounterEventType::LAST_CNT) {
|
||||
counter_handlers_[name].last_count_handler(message);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "DistributedCountService event type " << type << " is invalid.";
|
||||
return;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
void DistributedCountService::TriggerCounterEvent(const std::string &name) {
|
||||
MS_LOG(INFO) << "Current count for " << name << " is " << global_current_count_[name].size()
|
||||
<< ", threshold count is " << global_threshold_count_[name];
|
||||
// The threshold count may be 1 so the first and last count event should be both activated.
|
||||
if (global_current_count_[name].size() == 1) {
|
||||
TriggerFirstCountEvent(name);
|
||||
}
|
||||
if (global_current_count_[name].size() == global_threshold_count_[name]) {
|
||||
TriggerLastCountEvent(name);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
void DistributedCountService::TriggerFirstCountEvent(const std::string &name) {
|
||||
MS_LOG(INFO) << "Activating first count event for " << name;
|
||||
CounterEvent first_count_event;
|
||||
first_count_event.set_type(CounterEventType::FIRST_CNT);
|
||||
first_count_event.set_name(name);
|
||||
|
||||
// Broadcast to all follower servers.
|
||||
for (uint32_t i = 1; i < server_num_; i++) {
|
||||
if (!communicator_->SendPbRequest(first_count_event, i, core::TcpUserCommand::kCounterEvent)) {
|
||||
MS_LOG(ERROR) << "Activating first count event to server " << i << " failed.";
|
||||
return;
|
||||
}
|
||||
}
|
||||
// Leader server directly calls the callback.
|
||||
counter_handlers_[name].first_count_handler(nullptr);
|
||||
return;
|
||||
}
|
||||
|
||||
void DistributedCountService::TriggerLastCountEvent(const std::string &name) {
|
||||
MS_LOG(INFO) << "Activating last count event for " << name;
|
||||
CounterEvent last_count_event;
|
||||
last_count_event.set_type(CounterEventType::LAST_CNT);
|
||||
last_count_event.set_name(name);
|
||||
|
||||
// Broadcast to all follower servers.
|
||||
for (uint32_t i = 1; i < server_num_; i++) {
|
||||
if (!communicator_->SendPbRequest(last_count_event, i, core::TcpUserCommand::kCounterEvent)) {
|
||||
MS_LOG(ERROR) << "Activating last count event to server " << i << " failed.";
|
||||
return;
|
||||
}
|
||||
}
|
||||
// Leader server directly calls the callback.
|
||||
counter_handlers_[name].last_count_handler(nullptr);
|
||||
return;
|
||||
}
|
||||
} // namespace server
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,126 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PS_SERVER_DISTRIBUTED_COUNT_SERVICE_H_
|
||||
#define MINDSPORE_CCSRC_PS_SERVER_DISTRIBUTED_COUNT_SERVICE_H_
|
||||
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include "proto/ps.pb.h"
|
||||
#include "ps/server/common.h"
|
||||
#include "ps/core/server_node.h"
|
||||
#include "ps/core/communicator/tcp_communicator.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace server {
|
||||
// The callbacks for the first count and last count event.
|
||||
typedef struct {
|
||||
MessageCallback first_count_handler;
|
||||
MessageCallback last_count_handler;
|
||||
} CounterHandlers;
|
||||
|
||||
// DistributedCountService is used for counting in the server cluster dimension. It's used for counting of rounds,
|
||||
// aggregation counting, etc.
|
||||
|
||||
// The counting could be called by any server, but only one server has the information
|
||||
// of the cluster count and we mark this server as the counting server. Other servers must communicate with this
|
||||
// counting server to increase/query count number.
|
||||
|
||||
// On the first count or last count event, DistributedCountService on the counting server triggers the event on other
|
||||
// servers by sending counter event commands. This is for the purpose of keeping server cluster's consistency.
|
||||
class DistributedCountService {
|
||||
public:
|
||||
static DistributedCountService &GetInstance() {
|
||||
static DistributedCountService instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
// Initialize counter service with the server node because communication is needed.
|
||||
void Initialize(const std::shared_ptr<core::ServerNode> &server_node, uint32_t counting_server_rank);
|
||||
|
||||
// Register counter to the counting server for the name with its threshold count in server cluster dimension and
|
||||
// first/last count event callbacks.
|
||||
void RegisterCounter(const std::string &name, size_t global_threshold_count, const CounterHandlers &counter_handlers);
|
||||
|
||||
// Report a count to the counting server. Parameter 'id' is in case of repeated counting.
|
||||
bool Count(const std::string &name, const std::string &id);
|
||||
|
||||
// Query whether the count reaches the threshold count for the name. If the count is the same as the threshold count,
|
||||
// this method returns true.
|
||||
bool CountReachThreshold(const std::string &name);
|
||||
|
||||
// Reset the count of the name to 0.
|
||||
void ResetCounter(const std::string &name);
|
||||
|
||||
// Returns the server rank because in some cases the callers use this rank as the 'id' for method
|
||||
// Count.
|
||||
uint32_t local_rank() { return local_rank_; }
|
||||
|
||||
private:
|
||||
DistributedCountService() = default;
|
||||
~DistributedCountService() = default;
|
||||
DistributedCountService(const DistributedCountService &) = delete;
|
||||
DistributedCountService &operator=(const DistributedCountService &) = delete;
|
||||
|
||||
// Register callbacks of the counting server to handle messages sent by the other servers.
|
||||
void RegisterCallback();
|
||||
|
||||
// Callback for the reporting count message from other servers. Only counting server will call this method.
|
||||
void HandleCountRequest(const std::shared_ptr<core::MessageHandler> &message);
|
||||
|
||||
// Callback for the querying whether threshold count is reached message from other servers. Only counting
|
||||
// server will call this method.
|
||||
void HandleCountReachThresholdRequest(const std::shared_ptr<core::MessageHandler> &message);
|
||||
|
||||
// Callback for the first/last event message from the counting server. Only other servers will call this
|
||||
// method.
|
||||
void HandleCounterEvent(const std::shared_ptr<core::MessageHandler> &message);
|
||||
|
||||
// Call the callbacks when the first/last count event is triggered.
|
||||
void TriggerCounterEvent(const std::string &name);
|
||||
void TriggerFirstCountEvent(const std::string &name);
|
||||
void TriggerLastCountEvent(const std::string &name);
|
||||
|
||||
// Members for the communication between counting server and other servers.
|
||||
std::shared_ptr<core::ServerNode> server_node_;
|
||||
std::shared_ptr<core::TcpCommunicator> communicator_;
|
||||
uint32_t local_rank_;
|
||||
uint32_t server_num_;
|
||||
|
||||
// Only one server will be set to do the real counting.
|
||||
uint32_t counting_server_rank_;
|
||||
|
||||
// Key: name, e.g, startFLJob, updateModel, push.
|
||||
// Value: a set of id without repeatation because each work may report multiple times.
|
||||
std::unordered_map<std::string, std::set<std::string>> global_current_count_;
|
||||
|
||||
// Key: name, e.g, StartFLJobCount.
|
||||
// Value: global threshold count in the server cluster dimension for this name.
|
||||
std::unordered_map<std::string, size_t> global_threshold_count_;
|
||||
|
||||
// First/last count event callbacks of the name.
|
||||
std::unordered_map<std::string, CounterHandlers> counter_handlers_;
|
||||
|
||||
// Because the count is increased/queried conccurently, we must ensure the operations are threadsafe.
|
||||
std::unordered_map<std::string, std::mutex> mutex_;
|
||||
};
|
||||
} // namespace server
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PS_SERVER_DISTRIBUTED_COUNT_SERVICE_H_
|
|
@ -0,0 +1,201 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ps/server/distributed_metadata_store.h"
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace server {
|
||||
void DistributedMetadataStore::Initialize(const std::shared_ptr<core::ServerNode> &server_node) {
|
||||
server_node_ = server_node;
|
||||
MS_EXCEPTION_IF_NULL(server_node);
|
||||
|
||||
communicator_ =
|
||||
std::dynamic_pointer_cast<core::TcpCommunicator>(server_node_->GetOrCreateTcpComm("", 0, 0, 0, nullptr));
|
||||
MS_EXCEPTION_IF_NULL(communicator_);
|
||||
|
||||
local_rank_ = server_node_->rank_id();
|
||||
server_num_ = PSContext::instance()->initial_server_num();
|
||||
|
||||
InitHashRing();
|
||||
RegisterCallback();
|
||||
return;
|
||||
}
|
||||
|
||||
void DistributedMetadataStore::RegisterMetadata(const std::string &name, const PBMetadata &meta) {
|
||||
if (router_ == nullptr) {
|
||||
MS_LOG(ERROR) << "The consistent hash ring is not initialized yet.";
|
||||
return;
|
||||
}
|
||||
|
||||
uint32_t stored_rank = router_->Find(name);
|
||||
if (local_rank_ == stored_rank) {
|
||||
if (metadata_.count(name) != 0) {
|
||||
MS_LOG(ERROR) << "The metadata for " << name << " is already registered.";
|
||||
return;
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "Rank " << local_rank_ << " register storage for metadata " << name;
|
||||
metadata_[name] = meta;
|
||||
mutex_[name];
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
void DistributedMetadataStore::ResetMetadata(const std::string &name) {
|
||||
if (router_ == nullptr) {
|
||||
MS_LOG(ERROR) << "The consistent hash ring is not initialized yet.";
|
||||
return;
|
||||
}
|
||||
|
||||
uint32_t stored_rank = router_->Find(name);
|
||||
if (local_rank_ == stored_rank) {
|
||||
if (metadata_.count(name) == 0) {
|
||||
MS_LOG(ERROR) << "The metadata for " << name << " is not registered.";
|
||||
return;
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "Rank " << local_rank_ << " reset metadata for " << name;
|
||||
std::unique_lock<std::mutex> lock(mutex_[name]);
|
||||
PBMetadata empty_meta;
|
||||
metadata_[name] = empty_meta;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
void DistributedMetadataStore::UpdateMetadata(const std::string &name, const PBMetadata &meta) {
|
||||
if (router_ == nullptr) {
|
||||
MS_LOG(ERROR) << "The consistent hash ring is not initialized yet.";
|
||||
return;
|
||||
}
|
||||
|
||||
uint32_t stored_rank = router_->Find(name);
|
||||
MS_LOG(INFO) << "Rank " << local_rank_ << " update value for " << name << " which is stored in rank " << stored_rank;
|
||||
if (local_rank_ == stored_rank) {
|
||||
if (!DoUpdateMetadata(name, meta)) {
|
||||
MS_LOG(ERROR) << "Updating meta data failed.";
|
||||
return;
|
||||
}
|
||||
} else {
|
||||
PBMetadataWithName metadata_with_name;
|
||||
metadata_with_name.set_name(name);
|
||||
*metadata_with_name.mutable_metadata() = meta;
|
||||
if (!communicator_->SendPbRequest(metadata_with_name, stored_rank, core::TcpUserCommand::kUpdateMetadata)) {
|
||||
MS_LOG(ERROR) << "Sending updating metadata message to server " << stored_rank << " failed.";
|
||||
return;
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
PBMetadata DistributedMetadataStore::GetMetadata(const std::string &name) {
|
||||
if (router_ == nullptr) {
|
||||
MS_LOG(ERROR) << "The consistent hash ring is not initialized yet.";
|
||||
return {};
|
||||
}
|
||||
|
||||
uint32_t stored_rank = router_->Find(name);
|
||||
MS_LOG(INFO) << "Rank " << local_rank_ << " get metadata for " << name << " which is stored in rank " << stored_rank;
|
||||
if (local_rank_ == stored_rank) {
|
||||
std::unique_lock<std::mutex> lock(mutex_[name]);
|
||||
return metadata_[name];
|
||||
} else {
|
||||
GetMetadataRequest get_metadata_req;
|
||||
get_metadata_req.set_name(name);
|
||||
PBMetadata get_metadata_rsp;
|
||||
|
||||
std::shared_ptr<std::vector<unsigned char>> get_meta_rsp_msg = nullptr;
|
||||
if (!communicator_->SendPbRequest(get_metadata_req, stored_rank, core::TcpUserCommand::kGetMetadata,
|
||||
&get_meta_rsp_msg)) {
|
||||
MS_LOG(ERROR) << "Sending getting metadata message to server " << stored_rank << " failed.";
|
||||
return get_metadata_rsp;
|
||||
}
|
||||
get_metadata_rsp.ParseFromArray(get_meta_rsp_msg->data(), get_meta_rsp_msg->size());
|
||||
return get_metadata_rsp;
|
||||
}
|
||||
}
|
||||
|
||||
void DistributedMetadataStore::InitHashRing() {
|
||||
router_ = std::make_shared<ConsistentHashRing>(32);
|
||||
MS_EXCEPTION_IF_NULL(router_);
|
||||
for (uint32_t i = 0; i < server_num_; i++) {
|
||||
bool ret = router_->Insert(i);
|
||||
if (!ret) {
|
||||
MS_LOG(EXCEPTION) << "Add node " << i << " to router of meta storage failed.";
|
||||
return;
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
void DistributedMetadataStore::RegisterCallback() {
|
||||
communicator_->RegisterMsgCallBack(
|
||||
"updateMetadata", std::bind(&DistributedMetadataStore::HandleUpdateMetadataRequest, this, std::placeholders::_1));
|
||||
communicator_->RegisterMsgCallBack(
|
||||
"getMetadata", std::bind(&DistributedMetadataStore::HandleGetMetadataRequest, this, std::placeholders::_1));
|
||||
return;
|
||||
}
|
||||
|
||||
void DistributedMetadataStore::HandleUpdateMetadataRequest(const std::shared_ptr<core::MessageHandler> &message) {
|
||||
if (message == nullptr) {
|
||||
MS_LOG(ERROR) << "Message is nullptr.";
|
||||
return;
|
||||
}
|
||||
|
||||
PBMetadataWithName meta_with_name;
|
||||
meta_with_name.ParseFromArray(message->data(), message->len());
|
||||
const std::string &name = meta_with_name.name();
|
||||
MS_LOG(INFO) << "Update metadata for " << name;
|
||||
|
||||
std::string update_meta_rsp_msg;
|
||||
if (!DoUpdateMetadata(name, meta_with_name.metadata())) {
|
||||
update_meta_rsp_msg = "Updating meta data failed.";
|
||||
} else {
|
||||
update_meta_rsp_msg = "Success";
|
||||
}
|
||||
communicator_->SendResponse(update_meta_rsp_msg.data(), update_meta_rsp_msg.size(), message);
|
||||
return;
|
||||
}
|
||||
|
||||
void DistributedMetadataStore::HandleGetMetadataRequest(const std::shared_ptr<core::MessageHandler> &message) {
|
||||
if (message == nullptr) {
|
||||
MS_LOG(ERROR) << "Message is nullptr.";
|
||||
return;
|
||||
}
|
||||
|
||||
GetMetadataRequest get_metadata_req;
|
||||
get_metadata_req.ParseFromArray(message->data(), message->len());
|
||||
const std::string &name = get_metadata_req.name();
|
||||
MS_LOG(INFO) << "Getting metadata for " << name;
|
||||
|
||||
std::unique_lock<std::mutex> lock(mutex_[name]);
|
||||
PBMetadata stored_meta = metadata_[name];
|
||||
std::string getting_meta_rsp_msg = stored_meta.SerializeAsString();
|
||||
communicator_->SendResponse(getting_meta_rsp_msg.data(), getting_meta_rsp_msg.size(), message);
|
||||
return;
|
||||
}
|
||||
|
||||
bool DistributedMetadataStore::DoUpdateMetadata(const std::string &name, const PBMetadata &meta) {
|
||||
std::unique_lock<std::mutex> lock(mutex_[name]);
|
||||
metadata_[name] = meta;
|
||||
return true;
|
||||
}
|
||||
} // namespace server
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,101 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PS_SERVER_DISTRIBUTED_META_STORE_H_
|
||||
#define MINDSPORE_CCSRC_PS_SERVER_DISTRIBUTED_META_STORE_H_
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include "proto/ps.pb.h"
|
||||
#include "ps/server/common.h"
|
||||
#include "ps/core/server_node.h"
|
||||
#include "ps/core/communicator/tcp_communicator.h"
|
||||
#include "ps/server/consistent_hash_ring.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace server {
|
||||
// This class is used for distributed metadata storage using consistent hash. All metadata is distributedly
|
||||
// stored in all servers. Caller doesn't need to know which server stores the metadata. It only needs to know what kind
|
||||
// of operations should be done to the metadata.
|
||||
|
||||
// The metadata stored in the server is in protobuffer format because it's easy for serializing and communicating. The
|
||||
// type of the protobuffer struct is decided by the caller using protobuffer's API.
|
||||
class DistributedMetadataStore {
|
||||
public:
|
||||
static DistributedMetadataStore &GetInstance() {
|
||||
static DistributedMetadataStore instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
// Initialize metadata storage with the server node because communication is needed.
|
||||
void Initialize(const std::shared_ptr<core::ServerNode> &server_node);
|
||||
|
||||
// Register metadata for the name with the initial value. This method should be only called once for each name.
|
||||
void RegisterMetadata(const std::string &name, const PBMetadata &meta);
|
||||
|
||||
// Reset the metadata value for the name.
|
||||
void ResetMetadata(const std::string &name);
|
||||
|
||||
// Update the metadata for the name.
|
||||
void UpdateMetadata(const std::string &name, const PBMetadata &meta);
|
||||
|
||||
// Get the metadata for the name.
|
||||
PBMetadata GetMetadata(const std::string &name);
|
||||
|
||||
private:
|
||||
DistributedMetadataStore() = default;
|
||||
~DistributedMetadataStore() = default;
|
||||
DistributedMetadataStore(const DistributedMetadataStore &) = delete;
|
||||
DistributedMetadataStore &operator=(const DistributedMetadataStore &) = delete;
|
||||
|
||||
// Initialize the consistent hash ring for distributed storage.
|
||||
void InitHashRing();
|
||||
|
||||
// Register callbacks for the server to handle update/get metadata messages from other servers.
|
||||
void RegisterCallback();
|
||||
|
||||
// Callback for updating metadata request sent to the server.
|
||||
void HandleUpdateMetadataRequest(const std::shared_ptr<core::MessageHandler> &message);
|
||||
|
||||
// Callback for getting metadata request sent to the server.
|
||||
void HandleGetMetadataRequest(const std::shared_ptr<core::MessageHandler> &message);
|
||||
|
||||
// Do updating metadata in the server where the metadata for the name is stored.
|
||||
bool DoUpdateMetadata(const std::string &name, const PBMetadata &meta);
|
||||
|
||||
// Members for the communication between servers.
|
||||
std::shared_ptr<core::ServerNode> server_node_;
|
||||
std::shared_ptr<core::TcpCommunicator> communicator_;
|
||||
uint32_t local_rank_;
|
||||
uint32_t server_num_;
|
||||
|
||||
// Consistent hash ring. This is used for DistributedMetadataStore to find which server node the meta data is stored.
|
||||
std::shared_ptr<ConsistentHashRing> router_;
|
||||
|
||||
// We store metadata which is serialized by ProtoBuffer so that data storage and data transmission API is easy to use.
|
||||
// Key: data name.
|
||||
// Value: ProtoBuffer Struct.
|
||||
std::unordered_map<std::string, PBMetadata> metadata_;
|
||||
|
||||
// Because the metadata is read/written conccurently, we must ensure the operations are threadsafe.
|
||||
std::unordered_map<std::string, std::mutex> mutex_;
|
||||
};
|
||||
} // namespace server
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PS_SERVER_DISTRIBUTED_META_STORE_H_
|
|
@ -169,7 +169,7 @@ bool Executor::HandleOverwriteWeightsByKey(const std::map<std::string, Address>
|
|||
}
|
||||
|
||||
AddressPtr Executor::HandlePull(const std::string ¶m_name) {
|
||||
MS_LOG(INFO) << "Handle blocking pull msg for parameter " << param_name;
|
||||
MS_LOG(INFO) << "Handle blocking pull message for parameter " << param_name;
|
||||
if (param_aggrs_.count(param_name) == 0) {
|
||||
MS_LOG(WARNING) << "Parameter " << param_name << " is not registered in server.";
|
||||
return nullptr;
|
||||
|
@ -193,11 +193,6 @@ AddressPtr Executor::HandlePull(const std::string ¶m_name) {
|
|||
return addr;
|
||||
}
|
||||
|
||||
std::map<std::string, AddressPtr> Executor::HandleAsyncGetModel() {
|
||||
std::unique_lock<std::mutex> lock(model_mutex_);
|
||||
return GetModel();
|
||||
}
|
||||
|
||||
std::map<std::string, AddressPtr> Executor::HandleGetWeightsByKey(const std::vector<std::string> ¶m_names) {
|
||||
std::map<std::string, AddressPtr> weights;
|
||||
for (const auto ¶m_name : param_names) {
|
||||
|
|
|
@ -63,10 +63,6 @@ class Executor {
|
|||
// asynchronously.
|
||||
bool HandleModelUpdateAsync(const std::map<std::string, UploadData> &feature_map);
|
||||
|
||||
// Called in asynchronous federated learning training mode. Returns whole model in key-value where key refers to the
|
||||
// parameter name.
|
||||
std::map<std::string, AddressPtr> HandleAsyncGetModel();
|
||||
|
||||
// Forcibly overwrite specific weights in overwriteWeights message.
|
||||
bool HandleOverwriteWeightsByKey(const std::map<std::string, Address> &feature_map);
|
||||
|
||||
|
|
|
@ -0,0 +1,76 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ps/server/iteration.h"
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <numeric>
|
||||
#include "ps/server/model_store.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace server {
|
||||
Iteration::Iteration() : iteration_num_(1) { LocalMetaStore::GetInstance().set_curr_iter_num(iteration_num_); }
|
||||
|
||||
void Iteration::AddRound(const std::shared_ptr<Round> &round) {
|
||||
MS_EXCEPTION_IF_NULL(round);
|
||||
rounds_.push_back(round);
|
||||
}
|
||||
|
||||
void Iteration::InitRounds(const std::vector<std::shared_ptr<core::CommunicatorBase>> &communicators,
|
||||
const TimeOutCb &timeout_cb, const FinishIterCb &finish_iteration_cb) {
|
||||
if (communicators.empty()) {
|
||||
MS_LOG(EXCEPTION) << "Communicators for rounds is empty.";
|
||||
return;
|
||||
}
|
||||
|
||||
std::for_each(communicators.begin(), communicators.end(),
|
||||
[&](const std::shared_ptr<core::CommunicatorBase> &communicator) {
|
||||
for (auto &round : rounds_) {
|
||||
if (round == nullptr) {
|
||||
continue;
|
||||
}
|
||||
round->Initialize(communicator, timeout_cb, finish_iteration_cb);
|
||||
}
|
||||
});
|
||||
|
||||
// The time window for one iteration, which will be used in some round kernels.
|
||||
size_t iteration_time_window =
|
||||
std::accumulate(rounds_.begin(), rounds_.end(), 0,
|
||||
[](size_t total, const std::shared_ptr<Round> &round) { return total + round->time_window(); });
|
||||
LocalMetaStore::GetInstance().put_value(kCtxTotalTimeoutDuration, iteration_time_window);
|
||||
return;
|
||||
}
|
||||
|
||||
void Iteration::ProceedToNextIter() {
|
||||
iteration_num_ = LocalMetaStore::GetInstance().curr_iter_num();
|
||||
// Store the model for each iteration.
|
||||
const auto &model = Executor::GetInstance().GetModel();
|
||||
ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model);
|
||||
|
||||
for (auto &round : rounds_) {
|
||||
round->Reset();
|
||||
}
|
||||
|
||||
iteration_num_++;
|
||||
LocalMetaStore::GetInstance().set_curr_iter_num(iteration_num_);
|
||||
MS_LOG(INFO) << "Proceed to next iteration:" << iteration_num_ << "\n";
|
||||
}
|
||||
|
||||
const std::vector<std::shared_ptr<Round>> &Iteration::rounds() { return rounds_; }
|
||||
} // namespace server
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,58 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PS_SERVER_ITERATION_H_
|
||||
#define MINDSPORE_CCSRC_PS_SERVER_ITERATION_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include "ps/core/communicator/communicator_base.h"
|
||||
#include "ps/server/common.h"
|
||||
#include "ps/server/round.h"
|
||||
#include "ps/server/local_meta_store.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace server {
|
||||
// In server's logic, Iteration is the minimum execution unit. For each execution, it consists of multiple kinds of
|
||||
// Rounds, only after all the rounds are finished, this iteration is considered as completed.
|
||||
class Iteration {
|
||||
public:
|
||||
Iteration();
|
||||
~Iteration() = default;
|
||||
|
||||
// Add a round for the iteration. This method will be called multiple times for each round.
|
||||
void AddRound(const std::shared_ptr<Round> &round);
|
||||
|
||||
// Initialize all the rounds in the iteration.
|
||||
void InitRounds(const std::vector<std::shared_ptr<core::CommunicatorBase>> &communicators,
|
||||
const TimeOutCb &timeout_cb, const FinishIterCb &finish_iteration_cb);
|
||||
|
||||
// The server proceeds to the next iteration only after the last iteration finishes.
|
||||
void ProceedToNextIter();
|
||||
|
||||
const std::vector<std::shared_ptr<Round>> &rounds();
|
||||
|
||||
private:
|
||||
std::vector<std::shared_ptr<Round>> rounds_;
|
||||
|
||||
// Server's current iteration number.
|
||||
size_t iteration_num_;
|
||||
};
|
||||
} // namespace server
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PS_SERVER_ITERATION_H_
|
|
@ -0,0 +1,127 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ps/server/kernel/round/round_kernel.h"
|
||||
#include <mutex>
|
||||
#include <queue>
|
||||
#include <chrono>
|
||||
#include <thread>
|
||||
#include <utility>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace server {
|
||||
namespace kernel {
|
||||
RoundKernel::RoundKernel() : name_(""), current_count_(0), required_count_(0), error_reason_("") {
|
||||
release_thread_ = std::thread([&]() {
|
||||
while (true) {
|
||||
std::unique_lock<std::mutex> release_lock(release_mtx_);
|
||||
// Detect whether there's any data needs to be released every 100 milliseconds.
|
||||
if (heap_data_to_release_.empty()) {
|
||||
release_lock.unlock();
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||
continue;
|
||||
}
|
||||
|
||||
AddressPtr addr_ptr = heap_data_to_release_.front();
|
||||
heap_data_to_release_.pop();
|
||||
release_lock.unlock();
|
||||
|
||||
std::unique_lock<std::mutex> heap_data_lock(heap_data_mtx_);
|
||||
if (heap_data_.count(addr_ptr) == 0) {
|
||||
MS_LOG(ERROR) << "The data is not stored.";
|
||||
continue;
|
||||
}
|
||||
// Manually release unique_ptr data.
|
||||
heap_data_[addr_ptr].reset(nullptr);
|
||||
heap_data_.erase(heap_data_.find(addr_ptr));
|
||||
}
|
||||
});
|
||||
release_thread_.detach();
|
||||
}
|
||||
|
||||
void RoundKernel::OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &message) { return; }
|
||||
|
||||
void RoundKernel::OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message) { return; }
|
||||
|
||||
void RoundKernel::StopTimer() {
|
||||
if (stop_timer_cb_) {
|
||||
stop_timer_cb_();
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
void RoundKernel::FinishIteration() {
|
||||
if (finish_iteration_cb_) {
|
||||
finish_iteration_cb_();
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
void RoundKernel::Release(AddressPtr addr_ptr) {
|
||||
if (addr_ptr == nullptr) {
|
||||
MS_LOG(ERROR) << "Data to be released is empty.";
|
||||
return;
|
||||
}
|
||||
std::unique_lock<std::mutex> lock(release_mtx_);
|
||||
heap_data_to_release_.push(addr_ptr);
|
||||
return;
|
||||
}
|
||||
|
||||
void RoundKernel::set_name(const std::string &name) { name_ = name; }
|
||||
|
||||
void RoundKernel::set_stop_timer_cb(StopTimerCb timer_stopper) { stop_timer_cb_ = timer_stopper; }
|
||||
|
||||
void RoundKernel::set_finish_iteration_cb(FinishIterCb finish_iteration_cb) {
|
||||
finish_iteration_cb_ = finish_iteration_cb;
|
||||
}
|
||||
|
||||
void RoundKernel::GenerateOutput(const std::vector<AddressPtr> &outputs, void *data, size_t len) {
|
||||
if (data == nullptr) {
|
||||
MS_LOG(ERROR) << "The data is nullptr.";
|
||||
return;
|
||||
}
|
||||
|
||||
if (outputs.empty()) {
|
||||
MS_LOG(ERROR) << "Generating output failed. Outputs size is empty.";
|
||||
return;
|
||||
}
|
||||
|
||||
std::unique_ptr<unsigned char[]> output_data = std::make_unique<unsigned char[]>(len);
|
||||
if (output_data == nullptr) {
|
||||
MS_LOG(ERROR) << "Output data is nullptr.";
|
||||
return;
|
||||
}
|
||||
|
||||
size_t dst_size = len;
|
||||
int ret = memcpy_s(output_data.get(), dst_size, data, len);
|
||||
if (ret != 0) {
|
||||
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
|
||||
return;
|
||||
}
|
||||
outputs[0]->addr = output_data.get();
|
||||
outputs[0]->size = len;
|
||||
|
||||
std::unique_lock<std::mutex> lock(heap_data_mtx_);
|
||||
heap_data_.insert(std::make_pair(outputs[0], std::move(output_data)));
|
||||
return;
|
||||
}
|
||||
} // namespace kernel
|
||||
} // namespace server
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,130 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_ROUND_ROUND_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_ROUND_ROUND_KERNEL_H_
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <mutex>
|
||||
#include <queue>
|
||||
#include <utility>
|
||||
#include <chrono>
|
||||
#include <thread>
|
||||
#include <unordered_map>
|
||||
#include "backend/kernel_compiler/common_utils.h"
|
||||
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
|
||||
#include "ps/server/common.h"
|
||||
#include "ps/server/local_meta_store.h"
|
||||
#include "ps/server/distributed_count_service.h"
|
||||
#include "ps/server/distributed_metadata_store.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace server {
|
||||
namespace kernel {
|
||||
// RoundKernel contains the main logic of server handling messages from workers. One iteration has multiple round
|
||||
// kernels to represent the process. They receive and parse messages from the server communication module. After
|
||||
// handling these messages, round kernels allocate response data and send it back.
|
||||
|
||||
// For example, the main process of federated learning is:
|
||||
// startFLJob round->updateModel round->getModel round.
|
||||
class RoundKernel : virtual public CPUKernel {
|
||||
public:
|
||||
RoundKernel();
|
||||
virtual ~RoundKernel() = default;
|
||||
|
||||
// RoundKernel doesn't use InitKernel method of base class CPUKernel to initialize. So implementation of this
|
||||
// inherited method is empty.
|
||||
void InitKernel(const CNodePtr &kernel_node) override {}
|
||||
|
||||
// Initialize RoundKernel with threshold_count which means that for every iteration, this round needs threshold_count
|
||||
// messages.
|
||||
virtual void InitKernel(size_t threshold_count) = 0;
|
||||
|
||||
// Launch the round kernel logic to handle the message passed by the communication module.
|
||||
virtual bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) = 0;
|
||||
|
||||
// The callbacks when first message and last message for this round kernel is received.
|
||||
// These methods is called by class DistributedCountService and triggered by leader server(Rank 0).
|
||||
// virtual void OnFirstCountEvent(std::shared_ptr<core::MessageHandler> message);
|
||||
// virtual void OnLastCnt(std::shared_ptr<core::MessageHandler> message);
|
||||
|
||||
// Some rounds could be stateful in a iteration. Reset method resets the status of this round.
|
||||
virtual bool Reset() = 0;
|
||||
|
||||
// The counter event handlers for DistributedCountService.
|
||||
virtual void OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &message);
|
||||
virtual void OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message);
|
||||
|
||||
// Called when this round is finished. This round timer's Stop method will be called.
|
||||
void StopTimer();
|
||||
|
||||
// Called after this iteration(including all rounds) is finished. All rounds' Reset method will
|
||||
// be called.
|
||||
void FinishIteration();
|
||||
|
||||
// Release the response data allocated inside the round kernel.
|
||||
// Server framework must call this after the response data is sent back.
|
||||
void Release(AddressPtr addr_ptr);
|
||||
|
||||
// Set round kernel name, which could be used in round kernel's methods.
|
||||
void set_name(const std::string &name);
|
||||
|
||||
// Set callbacks to be called under certain triggered conditions.
|
||||
void set_stop_timer_cb(StopTimerCb timer_stopper);
|
||||
void set_finish_iteration_cb(FinishIterCb finish_iteration_cb);
|
||||
|
||||
protected:
|
||||
// Generating response data of this round. The data is allocated on the heap to ensure it's not released before sent
|
||||
// back to worker.
|
||||
void GenerateOutput(const std::vector<AddressPtr> &outputs, void *data, size_t len);
|
||||
|
||||
// Round kernel's name.
|
||||
std::string name_;
|
||||
|
||||
// The current received message count for this round in this iteration.
|
||||
size_t current_count_;
|
||||
|
||||
// The required received message count for this round in one iteration.
|
||||
size_t required_count_;
|
||||
|
||||
// The reason causes the error in this round kernel.
|
||||
std::string error_reason_;
|
||||
|
||||
StopTimerCb stop_timer_cb_;
|
||||
FinishIterCb finish_iteration_cb_;
|
||||
|
||||
// Members below are used for allocating and releasing response data on the heap.
|
||||
|
||||
// To ensure the performance, we use another thread to release data on the heap. So the operation on the data should
|
||||
// be threadsafe.
|
||||
std::thread release_thread_;
|
||||
|
||||
// Data needs to be released and its mutex;
|
||||
std::mutex release_mtx_;
|
||||
std::queue<AddressPtr> heap_data_to_release_;
|
||||
std::mutex heap_data_mtx_;
|
||||
std::unordered_map<AddressPtr, std::unique_ptr<unsigned char[]>> heap_data_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace server
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_ROUND_ROUND_KERNEL_H_
|
|
@ -0,0 +1,44 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ps/server/kernel/round/round_kernel_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace server {
|
||||
namespace kernel {
|
||||
RoundKernelFactory &RoundKernelFactory::GetInstance() {
|
||||
static RoundKernelFactory instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
void RoundKernelFactory::Register(const std::string &name, RoundKernelCreator &&creator) {
|
||||
name_to_creator_map_[name] = creator;
|
||||
}
|
||||
|
||||
std::shared_ptr<RoundKernel> RoundKernelFactory::Create(const std::string &name) {
|
||||
if (name_to_creator_map_.count(name) == 0) {
|
||||
MS_LOG(ERROR) << "Round kernel " << name << " is not registered.";
|
||||
return nullptr;
|
||||
}
|
||||
auto kernel = name_to_creator_map_[name]();
|
||||
kernel->set_name(name);
|
||||
return kernel;
|
||||
}
|
||||
} // namespace kernel
|
||||
} // namespace server
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,62 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_ROUND_ROUND_KERNEL_FACTORY_H_
|
||||
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_ROUND_ROUND_KERNEL_FACTORY_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <unordered_map>
|
||||
#include "ps/server/common.h"
|
||||
#include "ps/server/kernel/round/round_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace server {
|
||||
namespace kernel {
|
||||
using RoundKernelCreator = std::function<std::shared_ptr<RoundKernel>()>;
|
||||
// Kernel factory of round kernels.
|
||||
class RoundKernelFactory {
|
||||
public:
|
||||
static RoundKernelFactory &GetInstance();
|
||||
void Register(const std::string &name, RoundKernelCreator &&creator);
|
||||
std::shared_ptr<RoundKernel> Create(const std::string &name);
|
||||
|
||||
private:
|
||||
RoundKernelFactory() = default;
|
||||
~RoundKernelFactory() = default;
|
||||
RoundKernelFactory(const RoundKernelFactory &) = delete;
|
||||
RoundKernelFactory &operator=(const RoundKernelFactory &) = delete;
|
||||
|
||||
std::unordered_map<std::string, RoundKernelCreator> name_to_creator_map_;
|
||||
};
|
||||
|
||||
class RoundKernelRegister {
|
||||
public:
|
||||
RoundKernelRegister(const std::string &name, RoundKernelCreator &&creator) {
|
||||
RoundKernelFactory::GetInstance().Register(name, std::move(creator));
|
||||
}
|
||||
};
|
||||
|
||||
#define REG_ROUND_KERNEL(NAME, CLASS) \
|
||||
static_assert(std::is_base_of<RoundKernel, CLASS>::value, " must be base of RoundKernel"); \
|
||||
static const RoundKernelRegister g_##NAME##_round_kernel_reg(#NAME, []() { return std::make_shared<CLASS>(); });
|
||||
} // namespace kernel
|
||||
} // namespace server
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_ROUND_ROUND_KERNEL_FACTORY_H_
|
|
@ -0,0 +1,192 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ps/server/kernel/round/start_fl_job_kernel.h"
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace server {
|
||||
namespace kernel {
|
||||
void StartFLJobKernel::InitKernel(size_t) {
|
||||
if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) {
|
||||
iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
|
||||
}
|
||||
|
||||
executor_ = &Executor::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(executor_);
|
||||
if (!executor_->initialized()) {
|
||||
MS_LOG(EXCEPTION) << "Executor must be initialized in server pipeline.";
|
||||
return;
|
||||
}
|
||||
|
||||
PBMetadata devices_metas;
|
||||
DistributedMetadataStore::GetInstance().RegisterMetadata(kCtxDeviceMetas, devices_metas);
|
||||
return;
|
||||
}
|
||||
|
||||
bool StartFLJobKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
MS_LOG(INFO) << "Launching StartFLJobKernel kernel.";
|
||||
if (inputs.size() != 1 || outputs.size() != 1) {
|
||||
MS_LOG(ERROR) << "inputs or outputs size is invalid.";
|
||||
return false;
|
||||
}
|
||||
void *req_data = inputs[0]->addr;
|
||||
const std::shared_ptr<FBBuilder> &fbb = std::make_shared<FBBuilder>();
|
||||
if (fbb == nullptr || req_data == nullptr) {
|
||||
MS_LOG(ERROR) << "FBBuilder builder or req_data is nullptr.";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (ReachThresholdForStartFLJob(fbb)) {
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return false;
|
||||
}
|
||||
|
||||
const schema::RequestFLJob *start_fl_job_req = flatbuffers::GetRoot<schema::RequestFLJob>(req_data);
|
||||
DeviceMeta device_meta = CreateDeviceMetadata(start_fl_job_req);
|
||||
if (!ReadyForStartFLJob(fbb, device_meta)) {
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return false;
|
||||
}
|
||||
// If calling ReportCount before ReadyForStartFLJob, the result will be inconsistent if the device is not selected.
|
||||
if (!CountForStartFLJob(fbb, start_fl_job_req)) {
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return false;
|
||||
}
|
||||
|
||||
StartFLJob(fbb, device_meta);
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
|
||||
bool StartFLJobKernel::Reset() {
|
||||
MS_LOG(INFO) << "Starting fl job kernel reset!";
|
||||
StopTimer();
|
||||
DistributedCountService::GetInstance().ResetCounter(name_);
|
||||
DistributedMetadataStore::GetInstance().ResetMetadata(kCtxDeviceMetas);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool StartFLJobKernel::ReachThresholdForStartFLJob(const std::shared_ptr<FBBuilder> &fbb) {
|
||||
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
|
||||
std::string reason = "Current amount for startFLJob has reached the threshold. Please startFLJob later.";
|
||||
BuildStartFLJobRsp(fbb, schema::ResponseCode_OutOfTime, reason, false,
|
||||
std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_));
|
||||
MS_LOG(ERROR) << reason;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
DeviceMeta StartFLJobKernel::CreateDeviceMetadata(const schema::RequestFLJob *start_fl_job_req) {
|
||||
std::string fl_name = start_fl_job_req->fl_name()->str();
|
||||
std::string fl_id = start_fl_job_req->fl_id()->str();
|
||||
int data_size = start_fl_job_req->data_size();
|
||||
MS_LOG(INFO) << "DeviceMeta fl_name:" << fl_name << ", fl_id:" << fl_id << ", data_size:" << data_size;
|
||||
|
||||
DeviceMeta device_meta;
|
||||
device_meta.set_fl_name(fl_name);
|
||||
device_meta.set_fl_id(fl_id);
|
||||
device_meta.set_data_size(data_size);
|
||||
return device_meta;
|
||||
}
|
||||
|
||||
bool StartFLJobKernel::ReadyForStartFLJob(const std::shared_ptr<FBBuilder> &fbb, const DeviceMeta &device_meta) {
|
||||
bool ret = true;
|
||||
std::string reason = "";
|
||||
if (device_meta.data_size() < 1) {
|
||||
reason = "FL job data size is not enough.";
|
||||
ret = false;
|
||||
}
|
||||
if (!ret) {
|
||||
BuildStartFLJobRsp(fbb, schema::ResponseCode_NotSelected, reason, false,
|
||||
std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_));
|
||||
MS_LOG(ERROR) << reason;
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
bool StartFLJobKernel::CountForStartFLJob(const std::shared_ptr<FBBuilder> &fbb,
|
||||
const schema::RequestFLJob *start_fl_job_req) {
|
||||
if (!DistributedCountService::GetInstance().Count(name_, start_fl_job_req->fl_id()->str())) {
|
||||
std::string reason = "startFLJob counting failed.";
|
||||
BuildStartFLJobRsp(fbb, schema::ResponseCode_OutOfTime, reason, false,
|
||||
std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_));
|
||||
MS_LOG(ERROR) << reason;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void StartFLJobKernel::StartFLJob(const std::shared_ptr<FBBuilder> &fbb, const DeviceMeta &device_meta) {
|
||||
PBMetadata metadata;
|
||||
*metadata.mutable_device_meta() = device_meta;
|
||||
DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxDeviceMetas, metadata);
|
||||
|
||||
std::map<std::string, AddressPtr> feature_maps = executor_->GetModel();
|
||||
BuildStartFLJobRsp(fbb, schema::ResponseCode_SUCCEED, "success", true,
|
||||
std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_), feature_maps);
|
||||
return;
|
||||
}
|
||||
|
||||
void StartFLJobKernel::BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode,
|
||||
const std::string &reason, const bool is_selected,
|
||||
const std::string &next_req_time,
|
||||
std::map<std::string, AddressPtr> feature_maps) {
|
||||
auto fbs_reason = fbb->CreateString(reason);
|
||||
auto fbs_next_req_time = fbb->CreateString(next_req_time);
|
||||
auto fbs_fl_name = fbb->CreateString(PSContext::instance()->fl_name());
|
||||
|
||||
schema::FLPlanBuilder fl_plan_builder(*(fbb.get()));
|
||||
fl_plan_builder.add_fl_name(fbs_fl_name);
|
||||
fl_plan_builder.add_iterations(PSContext::instance()->fl_iteration_num());
|
||||
fl_plan_builder.add_epochs(PSContext::instance()->client_epoch_num());
|
||||
fl_plan_builder.add_mini_batch(PSContext::instance()->client_batch_size());
|
||||
auto fbs_fl_plan = fl_plan_builder.Finish();
|
||||
|
||||
std::vector<flatbuffers::Offset<schema::FeatureMap>> fbs_feature_maps;
|
||||
for (auto feature_map : feature_maps) {
|
||||
auto fbs_weight_fullname = fbb->CreateString(feature_map.first);
|
||||
auto fbs_weight_data =
|
||||
fbb->CreateVector(reinterpret_cast<float *>(feature_map.second->addr), feature_map.second->size / sizeof(float));
|
||||
auto fbs_feature_map = schema::CreateFeatureMap(*(fbb.get()), fbs_weight_fullname, fbs_weight_data);
|
||||
fbs_feature_maps.push_back(fbs_feature_map);
|
||||
}
|
||||
auto fbs_feature_maps_vector = fbb->CreateVector(fbs_feature_maps);
|
||||
|
||||
schema::ResponseFLJobBuilder rsp_fl_job_builder(*(fbb.get()));
|
||||
rsp_fl_job_builder.add_retcode(retcode);
|
||||
rsp_fl_job_builder.add_reason(fbs_reason);
|
||||
rsp_fl_job_builder.add_iteration(LocalMetaStore::GetInstance().curr_iter_num());
|
||||
rsp_fl_job_builder.add_is_selected(is_selected);
|
||||
rsp_fl_job_builder.add_next_req_time(fbs_next_req_time);
|
||||
rsp_fl_job_builder.add_fl_plan_config(fbs_fl_plan);
|
||||
rsp_fl_job_builder.add_feature_map(fbs_feature_maps_vector);
|
||||
auto rsp_fl_job = rsp_fl_job_builder.Finish();
|
||||
fbb->Finish(rsp_fl_job);
|
||||
return;
|
||||
}
|
||||
|
||||
REG_ROUND_KERNEL(startFLJob, StartFLJobKernel)
|
||||
} // namespace kernel
|
||||
} // namespace server
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,74 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_START_FL_JOB_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_START_FL_JOB_KERNEL_H_
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "ps/server/common.h"
|
||||
#include "ps/server/executor.h"
|
||||
#include "ps/server/kernel/round/round_kernel.h"
|
||||
#include "ps/server/kernel/round/round_kernel_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace server {
|
||||
namespace kernel {
|
||||
class StartFLJobKernel : public RoundKernel {
|
||||
public:
|
||||
StartFLJobKernel() = default;
|
||||
~StartFLJobKernel() override = default;
|
||||
|
||||
void InitKernel(size_t threshold_count) override;
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
bool Reset() override;
|
||||
|
||||
private:
|
||||
// Returns whether the startFLJob count of this iteration has reached the threshold.
|
||||
bool ReachThresholdForStartFLJob(const std::shared_ptr<FBBuilder> &fbb);
|
||||
|
||||
// The metadata of device will be stored and queried in updateModel round.
|
||||
DeviceMeta CreateDeviceMetadata(const schema::RequestFLJob *start_fl_job_req);
|
||||
|
||||
// Returns whether the request is valid for startFLJob.For now, the condition is simple. We will add more conditions
|
||||
// to device in later versions.
|
||||
bool ReadyForStartFLJob(const std::shared_ptr<FBBuilder> &fbb, const DeviceMeta &device_meta);
|
||||
|
||||
// Distributed count service counts for startFLJob.
|
||||
bool CountForStartFLJob(const std::shared_ptr<FBBuilder> &fbb, const schema::RequestFLJob *start_fl_job_req);
|
||||
|
||||
void StartFLJob(const std::shared_ptr<FBBuilder> &fbb, const DeviceMeta &device_meta);
|
||||
|
||||
// Build response for startFLJob round no matter success or failure.
|
||||
void BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode,
|
||||
const std::string &reason, const bool is_selected, const std::string &next_req_time,
|
||||
std::map<std::string, AddressPtr> feature_maps = {});
|
||||
|
||||
// The executor is for getting the initial model for startFLJob request.
|
||||
Executor *executor_;
|
||||
|
||||
// The time window of one iteration.
|
||||
size_t iteration_time_window_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace server
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_START_FL_JOB_KERNEL_H_
|
|
@ -14,30 +14,29 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ps/server/local_meta_storage.h"
|
||||
#include <string>
|
||||
#include "ps/server/local_meta_store.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace server {
|
||||
void LocalMetaStorage::remove_value(const std::string &name) {
|
||||
void LocalMetaStore::remove_value(const std::string &name) {
|
||||
std::unique_lock<std::mutex> lock(mtx_);
|
||||
if (key_to_meta_.count(name) != 0) {
|
||||
key_to_meta_.erase(key_to_meta_.find(name));
|
||||
}
|
||||
}
|
||||
|
||||
bool LocalMetaStorage::has_value(const std::string &name) {
|
||||
bool LocalMetaStore::has_value(const std::string &name) {
|
||||
std::unique_lock<std::mutex> lock(mtx_);
|
||||
return key_to_meta_.count(name) != 0;
|
||||
}
|
||||
|
||||
void LocalMetaStorage::set_curr_iter_num(size_t num) {
|
||||
void LocalMetaStore::set_curr_iter_num(size_t num) {
|
||||
std::unique_lock<std::mutex> lock(mtx_);
|
||||
curr_iter_num_ = num;
|
||||
}
|
||||
|
||||
const size_t LocalMetaStorage::curr_iter_num() {
|
||||
const size_t LocalMetaStore::curr_iter_num() {
|
||||
std::unique_lock<std::mutex> lock(mtx_);
|
||||
return curr_iter_num_;
|
||||
}
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PS_SERVER_LOCAL_META_STORAGE_H_
|
||||
#define MINDSPORE_CCSRC_PS_SERVER_LOCAL_META_STORAGE_H_
|
||||
#ifndef MINDSPORE_CCSRC_PS_SERVER_LOCAL_META_STORE_H_
|
||||
#define MINDSPORE_CCSRC_PS_SERVER_LOCAL_META_STORE_H_
|
||||
|
||||
#include <any>
|
||||
#include <mutex>
|
||||
|
@ -26,13 +26,13 @@
|
|||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace server {
|
||||
// LocalMetaStorage class is used for metadata storage of this server process.
|
||||
// LocalMetaStore class is used for metadata storage of this server process.
|
||||
// For example, the current iteration number, time windows for round kernels, etc.
|
||||
// LocalMetaStorage is threadsafe.
|
||||
class LocalMetaStorage {
|
||||
// LocalMetaStore is threadsafe.
|
||||
class LocalMetaStore {
|
||||
public:
|
||||
static LocalMetaStorage &GetInstance() {
|
||||
static LocalMetaStorage instance;
|
||||
static LocalMetaStore &GetInstance() {
|
||||
static LocalMetaStore instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
|
@ -43,7 +43,7 @@ class LocalMetaStorage {
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
const T &value(const std::string &name) {
|
||||
T value(const std::string &name) {
|
||||
std::unique_lock<std::mutex> lock(mtx_);
|
||||
try {
|
||||
T value = std::any_cast<T>(key_to_meta_[name]);
|
||||
|
@ -71,10 +71,10 @@ class LocalMetaStorage {
|
|||
const size_t curr_iter_num();
|
||||
|
||||
private:
|
||||
LocalMetaStorage() = default;
|
||||
~LocalMetaStorage() = default;
|
||||
LocalMetaStorage(const LocalMetaStorage &) = delete;
|
||||
LocalMetaStorage &operator=(const LocalMetaStorage &) = delete;
|
||||
LocalMetaStore() = default;
|
||||
~LocalMetaStore() = default;
|
||||
LocalMetaStore(const LocalMetaStore &) = delete;
|
||||
LocalMetaStore &operator=(const LocalMetaStore &) = delete;
|
||||
|
||||
// key_to_meta_ stores metadata with key-value format.
|
||||
std::unordered_map<std::string, std::any> key_to_meta_;
|
||||
|
@ -85,4 +85,4 @@ class LocalMetaStorage {
|
|||
} // namespace server
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PS_SERVER_LOCAL_META_STORAGE_H_
|
||||
#endif // MINDSPORE_CCSRC_PS_SERVER_LOCAL_META_STORE_H_
|
|
@ -0,0 +1,144 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ps/server/model_store.h"
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "ps/server/executor.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace server {
|
||||
void ModelStore::Init(uint32_t max_count) {
|
||||
if (!Executor::GetInstance().initialized()) {
|
||||
MS_LOG(EXCEPTION) << "Server's executor must be initialized before model storage.";
|
||||
return;
|
||||
}
|
||||
|
||||
max_model_count_ = max_count;
|
||||
iteration_to_model_[kInitIterationNum] = AssignNewModelMemory();
|
||||
model_size_ = ComputeModelSize();
|
||||
}
|
||||
|
||||
bool ModelStore::StoreModelByIterNum(size_t iteration, const std::map<std::string, AddressPtr> &new_model) {
|
||||
if (iteration_to_model_.count(iteration) != 0) {
|
||||
MS_LOG(WARNING) << "Model for iteration " << iteration << " is already stored";
|
||||
return false;
|
||||
}
|
||||
if (new_model.empty()) {
|
||||
MS_LOG(ERROR) << "Model feature map is empty.";
|
||||
return false;
|
||||
}
|
||||
|
||||
std::shared_ptr<MemoryRegister> memory_register;
|
||||
if (iteration_to_model_.size() < max_model_count_) {
|
||||
// If iteration_to_model_.size() is not max_model_count_, need to assign new memory for the model.
|
||||
memory_register = AssignNewModelMemory();
|
||||
if (memory_register == nullptr) {
|
||||
MS_LOG(ERROR) << "Memory for the new model is nullptr.";
|
||||
return false;
|
||||
}
|
||||
|
||||
iteration_to_model_[iteration] = memory_register;
|
||||
} else {
|
||||
// If iteration_to_model_ size is already max_model_count_, we need to replace earliest model with the newest model.
|
||||
memory_register = iteration_to_model_.begin()->second;
|
||||
if (memory_register == nullptr) {
|
||||
MS_LOG(ERROR) << "Earliest model is nullptr.";
|
||||
return false;
|
||||
}
|
||||
iteration_to_model_.erase(iteration_to_model_.begin());
|
||||
}
|
||||
|
||||
// Copy new model data to the the stored model.
|
||||
auto &stored_model = memory_register->addresses();
|
||||
for (const auto &weight : new_model) {
|
||||
const std::string &weight_name = weight.first;
|
||||
if (stored_model.count(weight_name) != 0) {
|
||||
MS_LOG(ERROR) << "The stored model has no weight " << weight_name;
|
||||
continue;
|
||||
}
|
||||
|
||||
void *dst_addr = stored_model[weight_name]->addr;
|
||||
size_t dst_size = stored_model[weight_name]->size;
|
||||
void *src_addr = weight.second->addr;
|
||||
size_t src_size = weight.second->size;
|
||||
int ret = memcpy_s(dst_addr, dst_size, src_addr, src_size);
|
||||
if (ret != 0) {
|
||||
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
iteration_to_model_[iteration] = memory_register;
|
||||
return true;
|
||||
}
|
||||
|
||||
std::map<std::string, AddressPtr> ModelStore::GetModelByIterNum(size_t iteration) {
|
||||
std::map<std::string, AddressPtr> model = {};
|
||||
if (iteration_to_model_.count(iteration) == 0) {
|
||||
MS_LOG(ERROR) << "Model for iteration " << iteration << " is not stored.";
|
||||
return model;
|
||||
}
|
||||
model = iteration_to_model_[iteration]->addresses();
|
||||
return model;
|
||||
}
|
||||
|
||||
const std::map<size_t, std::shared_ptr<MemoryRegister>> &ModelStore::iteration_to_model() const {
|
||||
return iteration_to_model_;
|
||||
}
|
||||
|
||||
size_t ModelStore::model_size() const { return model_size_; }
|
||||
|
||||
std::shared_ptr<MemoryRegister> ModelStore::AssignNewModelMemory() {
|
||||
std::map<std::string, AddressPtr> model = Executor::GetInstance().GetModel();
|
||||
if (model.empty()) {
|
||||
MS_LOG(EXCEPTION) << "Model feature map is empty.";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Assign new memory for the model.
|
||||
std::shared_ptr<MemoryRegister> memory_register = std::make_shared<MemoryRegister>();
|
||||
for (const auto &weight : model) {
|
||||
const std::string weight_name = weight.first;
|
||||
size_t weight_size = weight.second->size;
|
||||
auto weight_data = std::make_unique<char[]>(weight_size);
|
||||
if (weight_data == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Assign memory for weight failed.";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
memory_register->RegisterArray(weight_name, &weight_data, weight_size);
|
||||
}
|
||||
return memory_register;
|
||||
}
|
||||
|
||||
size_t ModelStore::ComputeModelSize() {
|
||||
if (iteration_to_model_.empty()) {
|
||||
MS_LOG(EXCEPTION) << "Calculating model size failed: model for iteration 0 is not stored yet. ";
|
||||
return 0;
|
||||
}
|
||||
|
||||
const auto &model = iteration_to_model_[kInitIterationNum];
|
||||
MS_EXCEPTION_IF_NULL(model);
|
||||
size_t model_size = std::accumulate(model->addresses().begin(), model->addresses().end(), static_cast<size_t>(0),
|
||||
[](size_t s, const auto &weight) { return s + weight.second->size; });
|
||||
MS_LOG(INFO) << "Model size in byte is " << model_size;
|
||||
return model_size;
|
||||
}
|
||||
} // namespace server
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,78 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PS_SERVER_MODEL_STORE_H_
|
||||
#define MINDSPORE_CCSRC_PS_SERVER_MODEL_STORE_H_
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include "ps/server/common.h"
|
||||
#include "ps/server/memory_register.h"
|
||||
#include "ps/server/executor.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace server {
|
||||
// The initial iteration number is 0 in server.
|
||||
constexpr size_t kInitIterationNum = 0;
|
||||
|
||||
// Server framework use ModelStore to store and query models.
|
||||
// ModelStore stores multiple models because worker could get models of the previous iterations.
|
||||
class ModelStore {
|
||||
public:
|
||||
static ModelStore &GetInstance() {
|
||||
static ModelStore instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
// Initialize ModelStore with max count of models need to be stored.
|
||||
void Init(uint32_t max_count = 3);
|
||||
|
||||
// Store the model of the given iteration. The model is acquired from Executor. If the current model count is already
|
||||
// max_model_count_, the earliest model will be replaced.
|
||||
bool StoreModelByIterNum(size_t iteration, const std::map<std::string, AddressPtr> &model);
|
||||
|
||||
// Get model of the given iteration.
|
||||
std::map<std::string, AddressPtr> GetModelByIterNum(size_t iteration);
|
||||
|
||||
// Returns all models stored in ModelStore.
|
||||
const std::map<size_t, std::shared_ptr<MemoryRegister>> &iteration_to_model() const;
|
||||
|
||||
// Returns the model size, which could be calculated at the initializing phase.
|
||||
size_t model_size() const;
|
||||
|
||||
private:
|
||||
ModelStore() : max_model_count_(0), model_size_(0), iteration_to_model_({}) {}
|
||||
~ModelStore() = default;
|
||||
ModelStore(const ModelStore &) = delete;
|
||||
ModelStore &operator=(const ModelStore &) = delete;
|
||||
|
||||
// To store multiple models, new memory must assigned. The max memory size assigned for models is max_model_count_ *
|
||||
// model_size_.
|
||||
std::shared_ptr<MemoryRegister> AssignNewModelMemory();
|
||||
|
||||
// Calculate the model size. This method should be called after iteration_to_model_ is initialized.
|
||||
size_t ComputeModelSize();
|
||||
|
||||
size_t max_model_count_;
|
||||
size_t model_size_;
|
||||
std::map<size_t, std::shared_ptr<MemoryRegister>> iteration_to_model_;
|
||||
};
|
||||
} // namespace server
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PS_SERVER_MODEL_STORE_H_
|
|
@ -25,15 +25,15 @@
|
|||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace server {
|
||||
bool ParameterAggregator::Init(const CNodePtr &cnode, size_t required_count) {
|
||||
bool ParameterAggregator::Init(const CNodePtr &cnode, size_t threshold_count) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
memory_register_ = std::make_shared<MemoryRegister>();
|
||||
MS_EXCEPTION_IF_NULL(memory_register_);
|
||||
|
||||
required_push_count_ = required_count;
|
||||
required_push_count_ = threshold_count;
|
||||
// The required_pull_count_ is the count for Pull, which should be the same as required_push_count_.
|
||||
// required_pull_count_ normally used in parameter server training mode.
|
||||
required_pull_count_ = required_count;
|
||||
required_pull_count_ = threshold_count;
|
||||
|
||||
MS_LOG(DEBUG) << "Start initializing kernels for " << AnfAlgo::GetCNodeName(cnode);
|
||||
InitAggregationKernels(cnode);
|
||||
|
|
|
@ -61,8 +61,8 @@ class ParameterAggregator {
|
|||
~ParameterAggregator() = default;
|
||||
|
||||
// Initialize ParameterAggregator with a cnode. This cnode is normally a optimizer kernel for now.
|
||||
// The parameter required_count helps ParameterAggregator to judge the current status if it's stateful.
|
||||
bool Init(const CNodePtr &cnode, size_t required_count = 0);
|
||||
// The parameter threshold_count helps ParameterAggregator to judge the current status if it's stateful.
|
||||
bool Init(const CNodePtr &cnode, size_t threshold_count = 0);
|
||||
|
||||
// Update old data stored in ParameterAggregator with new data.
|
||||
// The data could have many meanings: weights, gradients, learning_rate, momentum, etc.
|
||||
|
|
|
@ -0,0 +1,139 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ps/server/round.h"
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace server {
|
||||
Round::Round(const std::string &name, bool check_timeout, size_t time_window, bool check_count, size_t threshold_count)
|
||||
: name_(name),
|
||||
check_timeout_(check_timeout),
|
||||
time_window_(time_window),
|
||||
check_count_(check_count),
|
||||
threshold_count_(threshold_count) {}
|
||||
|
||||
void Round::Initialize(const std::shared_ptr<core::CommunicatorBase> &communicator, TimeOutCb timeout_cb,
|
||||
FinishIterCb finish_iteration_cb) {
|
||||
MS_EXCEPTION_IF_NULL(communicator);
|
||||
communicator_ = communicator;
|
||||
|
||||
// Register callback for round kernel.
|
||||
communicator_->RegisterMsgCallBack(
|
||||
name_, [&](std::shared_ptr<core::MessageHandler> message) { LaunchRoundKernel(message); });
|
||||
|
||||
// Callback when the iteration is finished.
|
||||
finish_iteration_cb_ = [this, finish_iteration_cb](void) -> void {
|
||||
MS_LOG(INFO) << "Round " << name_ << " finished! Proceed to next iteration.";
|
||||
finish_iteration_cb();
|
||||
};
|
||||
|
||||
// Callback for finalizing the server. This can only be called once.
|
||||
finalize_cb_ = [&](void) -> void { communicator_->Stop(); };
|
||||
|
||||
if (check_timeout_) {
|
||||
iter_timer_ = std::make_shared<IterationTimer>();
|
||||
|
||||
// 1.Set the timeout callback for the timer.
|
||||
iter_timer_->SetTimeOutCallBack([this, timeout_cb](void) -> void {
|
||||
MS_LOG(INFO) << "Round " << name_ << " timeout! Proceed to next iteration.";
|
||||
timeout_cb();
|
||||
});
|
||||
|
||||
// 2.Stopping timer callback which will be set to the round kernel.
|
||||
stop_timer_cb_ = [&](void) -> void {
|
||||
MS_LOG(INFO) << "Round " << name_ << " kernel stops its timer.";
|
||||
iter_timer_->Stop();
|
||||
};
|
||||
}
|
||||
|
||||
// Set counter event callbacks for this round if the round kernel is stateful.
|
||||
if (check_count_) {
|
||||
auto first_count_handler = std::bind(&Round::OnFirstCountEvent, this, std::placeholders::_1);
|
||||
auto last_count_handler = std::bind(&Round::OnLastCountEvent, this, std::placeholders::_1);
|
||||
DistributedCountService::GetInstance().RegisterCounter(name_, threshold_count_,
|
||||
{first_count_handler, last_count_handler});
|
||||
}
|
||||
}
|
||||
|
||||
void Round::BindRoundKernel(const std::shared_ptr<kernel::RoundKernel> &kernel) {
|
||||
MS_EXCEPTION_IF_NULL(kernel);
|
||||
kernel_ = kernel;
|
||||
kernel_->set_stop_timer_cb(stop_timer_cb_);
|
||||
kernel_->set_finish_iteration_cb(finish_iteration_cb_);
|
||||
return;
|
||||
}
|
||||
|
||||
void Round::LaunchRoundKernel(const std::shared_ptr<core::MessageHandler> &message) {
|
||||
if (message == nullptr) {
|
||||
MS_LOG(ERROR) << "Message is nullptr.";
|
||||
return;
|
||||
}
|
||||
|
||||
AddressPtr input = std::make_shared<Address>();
|
||||
AddressPtr output = std::make_shared<Address>();
|
||||
input->addr = message->data();
|
||||
input->size = message->len();
|
||||
bool ret = kernel_->Launch({input}, {}, {output});
|
||||
if (output->size == 0) {
|
||||
std::string reason = "The output of the round " + name_ + " is empty.";
|
||||
MS_LOG(WARNING) << reason;
|
||||
communicator_->SendResponse(reason.c_str(), reason.size(), message);
|
||||
return;
|
||||
}
|
||||
|
||||
// Must send response back no matter what value Launch method returns.
|
||||
if (!ret) {
|
||||
MS_LOG(WARNING) << "Launching round kernel of round " << name_ << " failed.";
|
||||
}
|
||||
communicator_->SendResponse(output->addr, output->size, message);
|
||||
kernel_->Release(output);
|
||||
return;
|
||||
}
|
||||
|
||||
void Round::Reset() { kernel_->Reset(); }
|
||||
|
||||
const std::string &Round::name() const { return name_; }
|
||||
|
||||
size_t Round::threshold_count() const { return threshold_count_; }
|
||||
|
||||
size_t Round::time_window() const { return time_window_; }
|
||||
|
||||
void Round::OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &) {
|
||||
MS_LOG(INFO) << "Round " << name_ << " first count event is triggered.";
|
||||
// The timer starts only after the first count event is triggered by DistributedCountService.
|
||||
if (check_timeout_) {
|
||||
iter_timer_->Start(std::chrono::milliseconds(time_window_));
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
void Round::OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message) {
|
||||
MS_LOG(INFO) << "Round " << name_ << " last count event is triggered.";
|
||||
// Same as the first count event, the timer must be stopped by DistributedCountService.
|
||||
if (check_timeout_) {
|
||||
iter_timer_->Stop();
|
||||
}
|
||||
|
||||
// Some kernels override the OnLastCountEvent method.
|
||||
kernel_->OnLastCountEvent(message);
|
||||
return;
|
||||
}
|
||||
} // namespace server
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,95 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PS_SERVER_ROUND_H_
|
||||
#define MINDSPORE_CCSRC_PS_SERVER_ROUND_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include "ps/core/communicator/communicator_base.h"
|
||||
#include "ps/server/common.h"
|
||||
#include "ps/server/iteration_timer.h"
|
||||
#include "ps/server/distributed_count_service.h"
|
||||
#include "ps/server/kernel/round/round_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace server {
|
||||
// Round helps server to handle network round messages and launch round kernels. One iteration in server consists of
|
||||
// multiple rounds like startFLJob, updateModel, Push, Pull, etc. Some round kernels may be stateful because of counting
|
||||
// and timing. So Round helps register counter and timer so that the round kernels only need to focus on the logic.
|
||||
class Round {
|
||||
public:
|
||||
explicit Round(const std::string &name, bool check_timeout = true, size_t time_window = 3000,
|
||||
bool check_count = false, size_t threshold_count = 8);
|
||||
~Round() = default;
|
||||
|
||||
void Initialize(const std::shared_ptr<core::CommunicatorBase> &communicator, TimeOutCb timeout_cb,
|
||||
FinishIterCb finish_iteration_cb);
|
||||
|
||||
// Bind a round kernel to this Round. This method should be called after Initialize.
|
||||
void BindRoundKernel(const std::shared_ptr<kernel::RoundKernel> &kernel);
|
||||
|
||||
// This method is the callback which will be set to the communicator and called after the corresponding round message
|
||||
// is sent to the server.
|
||||
void LaunchRoundKernel(const std::shared_ptr<core::MessageHandler> &message);
|
||||
|
||||
// Round needs to be reset after each iteration is finished or its timer expires.
|
||||
void Reset();
|
||||
|
||||
const std::string &name() const;
|
||||
size_t threshold_count() const;
|
||||
size_t time_window() const;
|
||||
|
||||
private:
|
||||
// The callbacks which will be set to DistributedCounterService.
|
||||
void OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &message);
|
||||
void OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message);
|
||||
|
||||
std::string name_;
|
||||
|
||||
// Whether this round needs to use timer. Most rounds in federated learning with mobile devices scenario need to set
|
||||
// check_timeout_ to true.
|
||||
bool check_timeout_;
|
||||
|
||||
// The time window duration for this round when check_timeout_ is set to true.
|
||||
size_t time_window_;
|
||||
|
||||
// If check_count_ is true, it means the round has to do counting for every round message and the first/last count
|
||||
// event will be triggered.
|
||||
bool check_count_;
|
||||
|
||||
// The threshold count for this round when check_count_ is set to true. The logic of this round has to check whether
|
||||
// the round message count has reached threshold_count_.
|
||||
size_t threshold_count_;
|
||||
|
||||
std::shared_ptr<core::CommunicatorBase> communicator_;
|
||||
|
||||
// The round kernel for this Round.
|
||||
std::shared_ptr<kernel::RoundKernel> kernel_;
|
||||
|
||||
// Some rounds may need timer to eliminate the long tail effect.
|
||||
std::shared_ptr<IterationTimer> iter_timer_;
|
||||
|
||||
// The callbacks which will be set to the round kernel.
|
||||
StopTimerCb stop_timer_cb_;
|
||||
FinishIterCb finish_iteration_cb_;
|
||||
FinalizeCb finalize_cb_;
|
||||
};
|
||||
} // namespace server
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PS_SERVER_ROUND_H_
|
|
@ -31,7 +31,7 @@
|
|||
#include "utils/utils.h"
|
||||
#include "frontend/parallel/context.h"
|
||||
#include "debug/env_config_parser.h"
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
#if (ENABLE_CPU && !_WIN32)
|
||||
#include "ps/ps_cache/ps_cache_manager.h"
|
||||
#endif
|
||||
|
||||
|
@ -307,7 +307,7 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) {
|
|||
}
|
||||
need_alloc_nodes.push_back(item);
|
||||
}
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
#if (ENABLE_CPU && !_WIN32)
|
||||
bool ps_cache_check = false;
|
||||
#endif
|
||||
for (auto &item : need_alloc_nodes) {
|
||||
|
@ -320,7 +320,7 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) {
|
|||
continue;
|
||||
}
|
||||
DeviceAddressPtr device_address = nullptr;
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
#if (ENABLE_CPU && !_WIN32)
|
||||
const std::string ¶m_name = item->fullname_with_scope();
|
||||
if (ps::ps_cache_instance.IsHashTable(param_name)) {
|
||||
MS_LOG(INFO) << "Parameter(" << param_name << ")"
|
||||
|
@ -1038,7 +1038,7 @@ DeviceAddressPtr KernelRuntime::AssignSingleOpLaunchMemory(size_t size, const st
|
|||
return device_address;
|
||||
}
|
||||
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
#if (ENABLE_CPU && !_WIN32)
|
||||
void KernelRuntime::GetFirstPSEmbeddingCache(const session::KernelGraph *graph,
|
||||
AnfNodePtr *const first_cache_input_index,
|
||||
size_t *const first_cache_size) {
|
||||
|
|
|
@ -142,7 +142,7 @@ class KernelRuntime {
|
|||
void RunOpAssignOutputNodeMemory(const ValuePtr &pre_output_value, session::KernelGraph *graph);
|
||||
void AssignValueNodeTensor(const ValueNodePtr &value_node, const ValuePtr &node_value, size_t output_idx);
|
||||
DeviceAddressPtr PreAssignCNodeMemory(const AnfNodePtr &anf_node, size_t index);
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
#if (ENABLE_CPU && !_WIN32)
|
||||
void GetFirstPSEmbeddingCache(const session::KernelGraph *graph, AnfNodePtr *const first_cache_input_index,
|
||||
size_t *const first_cache_size);
|
||||
void CheckIfSupportPSEmbeddingCache(const session::KernelGraph *graph);
|
||||
|
|
|
@ -16,14 +16,14 @@
|
|||
|
||||
#include "runtime/device/kernel_runtime_manager.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
#if (ENABLE_CPU && !_WIN32)
|
||||
#include "ps/ps_cache/ps_cache_manager.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
namespace device {
|
||||
void KernelRuntimeManager::ClearRuntimeResource() {
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
#if (ENABLE_CPU && !_WIN32)
|
||||
if (ps::PSContext::instance()->is_worker() && ps::PsDataPrefetch::GetInstance().cache_enable()) {
|
||||
ps::ps_cache_instance.SyncEmbeddingTable();
|
||||
}
|
||||
|
@ -125,7 +125,7 @@ void KernelRuntimeManager::ReleaseKernelRuntime(const std::string &device_name,
|
|||
if (runtime == nullptr) {
|
||||
return;
|
||||
}
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
#if (ENABLE_CPU && !_WIN32)
|
||||
if (ps::PSContext::instance()->is_worker() && ps::PsDataPrefetch::GetInstance().cache_enable()) {
|
||||
ps::ps_cache_instance.SyncEmbeddingTable();
|
||||
}
|
||||
|
|
|
@ -0,0 +1,123 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
namespace mindspore.schema;
|
||||
|
||||
table CipherPublicParams {
|
||||
t:int;
|
||||
p:[ubyte];
|
||||
g:int;
|
||||
prime:[ubyte];
|
||||
dp_eps:float;
|
||||
dp_delta:float;
|
||||
dp_norm_clip:float;
|
||||
encrypt_type:int;
|
||||
}
|
||||
|
||||
table ClientPublicKeys {
|
||||
fl_id:string;
|
||||
c_pk:[ubyte];
|
||||
s_pk: [ubyte];
|
||||
}
|
||||
|
||||
table ClientShare {
|
||||
fl_id:string;
|
||||
share:[ubyte];
|
||||
index:int;
|
||||
}
|
||||
|
||||
table RequestExchangeKeys{
|
||||
fl_id:string;
|
||||
c_pk:[ubyte];
|
||||
s_pk:[ubyte];
|
||||
iteration:int;
|
||||
timestamp:string;
|
||||
}
|
||||
|
||||
table ResponseExchangeKeys{
|
||||
retcode:int;
|
||||
reason:string;
|
||||
next_req_time:string;
|
||||
iteration:int;
|
||||
}
|
||||
|
||||
table GetExchangeKeys{
|
||||
fl_id:string;
|
||||
iteration:int;
|
||||
timestamp:string;
|
||||
}
|
||||
|
||||
table ReturnExchangeKeys{
|
||||
retcode:int;
|
||||
iteration:int;
|
||||
remote_publickeys:[ClientPublicKeys];
|
||||
next_req_time:string;
|
||||
}
|
||||
|
||||
table RequestShareSecrets{
|
||||
fl_id:string;
|
||||
encrypted_shares:[ClientShare];
|
||||
iteration:int;
|
||||
timestamp:string;
|
||||
}
|
||||
|
||||
table ResponseShareSecrets{
|
||||
retcode:int;
|
||||
reason:string;
|
||||
next_req_time:string;
|
||||
iteration:int;
|
||||
}
|
||||
|
||||
table GetShareSecrets{
|
||||
fl_id:string;
|
||||
iteration:int;
|
||||
timestamp:string;
|
||||
}
|
||||
|
||||
table ReturnShareSecrets{
|
||||
retcode:int;
|
||||
iteration:int;
|
||||
encrypted_shares: [ClientShare];
|
||||
next_req_time:string;
|
||||
}
|
||||
|
||||
table GetClientList{
|
||||
fl_id:string;
|
||||
iteration:int;
|
||||
timestamp:string;
|
||||
}
|
||||
|
||||
table ReturnClientList{
|
||||
retcode:int;
|
||||
reason:string;
|
||||
clients:[string];
|
||||
iteration:int;
|
||||
next_req_time:string;
|
||||
}
|
||||
|
||||
table SendReconstructSecret{
|
||||
fl_id:string;
|
||||
reconstruct_secret_shares:[ClientShare];
|
||||
iteration:int;
|
||||
timestamp:string;
|
||||
}
|
||||
|
||||
table ReconstructSecret{
|
||||
retcode:int;
|
||||
reason:string;
|
||||
iteration:int;
|
||||
next_req_time:string;
|
||||
}
|
|
@ -0,0 +1,159 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
include "cipher.fbs";
|
||||
namespace mindspore.schema;
|
||||
|
||||
file_identifier "FLJ0";
|
||||
file_extension "fl";
|
||||
|
||||
enum ResponseCode: int {
|
||||
SUCCEED=200,
|
||||
SucNotReady=201,
|
||||
RepeatRequest=202,
|
||||
SucNotMatch=204,
|
||||
OutOfTime=300,
|
||||
NotSelected=301,
|
||||
RequestError=400,
|
||||
SystemError=500
|
||||
}
|
||||
|
||||
enum AggregationType:byte {FedAvg=0, FedAdam = 1, FedAdagrag=2, FedMeta=3, qffl=4}
|
||||
enum Metrics:byte {accuracy = 0, precision = 1, recall = 2, AUC = 3,f1=4, fbeta=5}
|
||||
enum EarlyStopType:byte {loss_diff = 0, loss_abs = 1, weight_diff = 2}
|
||||
|
||||
table Aggregation {
|
||||
type:AggregationType;
|
||||
weights:[float];
|
||||
}
|
||||
|
||||
table EarlyStop {
|
||||
early_stop_type:EarlyStopType;
|
||||
weight:float;
|
||||
rounds:int;
|
||||
}
|
||||
|
||||
table FeatureMap{
|
||||
weight_fullname:string;
|
||||
data:[float];
|
||||
}
|
||||
table RequestFLJob{
|
||||
fl_name:string;
|
||||
fl_id:string;
|
||||
iteration:int;
|
||||
data_size:int;
|
||||
timestamp:string;
|
||||
}
|
||||
table ResponseFLJob {
|
||||
retcode:int;
|
||||
reason:string;
|
||||
iteration:int;
|
||||
is_selected:bool = false;
|
||||
next_req_time:string;
|
||||
fl_plan_config:FLPlan;
|
||||
feature_map:[FeatureMap];
|
||||
timestamp:string;
|
||||
}
|
||||
|
||||
table FLPlan {
|
||||
fl_name:string;
|
||||
iterations:int;
|
||||
epochs:int;
|
||||
early_stop:EarlyStop;
|
||||
mini_batch:int;
|
||||
shuffle:bool = false;
|
||||
lr:float;
|
||||
aggregation:Aggregation;
|
||||
metrics:[Metrics];
|
||||
cipher:CipherPublicParams;
|
||||
}
|
||||
|
||||
table RequestUpdateModel{
|
||||
fl_name:string;
|
||||
fl_id:string;
|
||||
iteration:int;
|
||||
feature_map:[FeatureMap];
|
||||
timestamp:string;
|
||||
}
|
||||
table ResponseUpdateModel{
|
||||
retcode:int;
|
||||
reason:string;
|
||||
feature_map:[FeatureMap];
|
||||
next_req_time:string;
|
||||
timestamp:string;
|
||||
}
|
||||
|
||||
table RequestAsyncUpdateModel{
|
||||
fl_name:string;
|
||||
fl_id:string;
|
||||
iteration:int;
|
||||
data_size:int;
|
||||
feature_map:[FeatureMap];
|
||||
}
|
||||
table ResponseAsyncUpdateModel{
|
||||
retcode:int;
|
||||
reason:string;
|
||||
iteration:int;
|
||||
}
|
||||
|
||||
table RequestOverwriteWeightsByKey{
|
||||
iteration:int;
|
||||
feature_map:[FeatureMap];
|
||||
}
|
||||
|
||||
table ResponseOverwriteWeightsByKey{
|
||||
retcode:int;
|
||||
reason:string;
|
||||
}
|
||||
|
||||
table RequestGetModel{
|
||||
fl_name:string;
|
||||
iteration:int;
|
||||
timestamp:string;
|
||||
}
|
||||
table ResponseGetModel{
|
||||
retcode:int;
|
||||
reason:string;
|
||||
iteration:int;
|
||||
feature_map:[FeatureMap];
|
||||
timestamp:string;
|
||||
}
|
||||
|
||||
table RequestAsyncGetModel{
|
||||
fl_name:string;
|
||||
iteration:int;
|
||||
}
|
||||
table ResponseAsyncGetModel{
|
||||
retcode:int;
|
||||
reason:string;
|
||||
iteration:int;
|
||||
feature_map:[FeatureMap];
|
||||
}
|
||||
|
||||
table RequestGetWeightsByKey{
|
||||
iteration:int;
|
||||
weight_names:[string];
|
||||
}
|
||||
table ResponseGetWeightsByKey{
|
||||
retcode:int;
|
||||
reason:string;
|
||||
feature_map:[FeatureMap];
|
||||
}
|
||||
|
||||
// FeatureMapList refers to the whole trained model.
|
||||
table FeatureMapList {
|
||||
feature_map:[FeatureMap];
|
||||
}
|
Loading…
Reference in New Issue