From 12f95b51f445ef50192ae19f084db2d49e2767e5 Mon Sep 17 00:00:00 2001 From: ZPaC Date: Mon, 26 Apr 2021 14:36:19 +0800 Subject: [PATCH] Add server code part2 --- cmake/external_libs/flatbuffers.cmake | 2 +- cmake/package.cmake | 2 +- mindspore/ccsrc/CMakeLists.txt | 2 +- .../backend/kernel_compiler/CMakeLists.txt | 2 +- .../ccsrc/backend/session/ascend_session.cc | 2 +- .../ccsrc/backend/session/cpu_session.cc | 6 +- mindspore/ccsrc/backend/session/executor.cc | 2 +- .../ccsrc/backend/session/session_basic.cc | 4 +- .../ccsrc/backend/session/session_basic.h | 4 +- .../parallel/ops_info/gather_v2_p_info.cc | 6 +- .../frontend/parallel/ops_info/unique_info.cc | 6 +- .../frontend/parallel/ops_info/unique_info.h | 2 +- .../frontend/parallel/step_auto_parallel.cc | 4 +- .../ccsrc/frontend/parallel/step_parallel.cc | 4 +- .../ccsrc/minddata/dataset/CMakeLists.txt | 2 +- .../dataset/engine/cache/CMakeLists.txt | 3 +- mindspore/ccsrc/pipeline/jit/action.cc | 8 +- mindspore/ccsrc/pipeline/jit/init.cc | 2 +- mindspore/ccsrc/pipeline/jit/pass.cc | 4 +- mindspore/ccsrc/pipeline/jit/pipeline.cc | 10 +- mindspore/ccsrc/ps/CMakeLists.txt | 34 +- .../ps/core/communicator/tcp_communicator.h | 23 +- mindspore/ccsrc/ps/core/protos/fl.proto | 155 +++++++++ mindspore/ccsrc/ps/core/protos/ps.proto | 2 +- mindspore/ccsrc/ps/ps_cache/CMakeLists.txt | 2 +- mindspore/ccsrc/ps/ps_context.cc | 34 +- mindspore/ccsrc/ps/ps_context.h | 19 ++ .../ccsrc/ps/server/collective_ops_impl.cc | 223 +++++++++++++ .../ccsrc/ps/server/collective_ops_impl.h | 71 +++++ mindspore/ccsrc/ps/server/common.h | 12 +- .../ps/server/distributed_count_service.cc | 298 ++++++++++++++++++ .../ps/server/distributed_count_service.h | 126 ++++++++ .../ps/server/distributed_metadata_store.cc | 201 ++++++++++++ .../ps/server/distributed_metadata_store.h | 101 ++++++ mindspore/ccsrc/ps/server/executor.cc | 7 +- mindspore/ccsrc/ps/server/executor.h | 4 - mindspore/ccsrc/ps/server/iteration.cc | 76 +++++ mindspore/ccsrc/ps/server/iteration.h | 58 ++++ .../ps/server/kernel/round/round_kernel.cc | 127 ++++++++ .../ps/server/kernel/round/round_kernel.h | 130 ++++++++ .../kernel/round/round_kernel_factory.cc | 44 +++ .../kernel/round/round_kernel_factory.h | 62 ++++ .../kernel/round/start_fl_job_kernel.cc | 192 +++++++++++ .../server/kernel/round/start_fl_job_kernel.h | 74 +++++ ...al_meta_storage.cc => local_meta_store.cc} | 11 +- ...ocal_meta_storage.h => local_meta_store.h} | 26 +- mindspore/ccsrc/ps/server/model_store.cc | 144 +++++++++ mindspore/ccsrc/ps/server/model_store.h | 78 +++++ .../ccsrc/ps/server/parameter_aggregator.cc | 6 +- .../ccsrc/ps/server/parameter_aggregator.h | 4 +- mindspore/ccsrc/ps/server/round.cc | 139 ++++++++ mindspore/ccsrc/ps/server/round.h | 95 ++++++ .../ccsrc/runtime/device/kernel_runtime.cc | 8 +- .../ccsrc/runtime/device/kernel_runtime.h | 2 +- .../runtime/device/kernel_runtime_manager.cc | 6 +- mindspore/schema/cipher.fbs | 123 ++++++++ mindspore/schema/fl_job.fbs | 159 ++++++++++ 57 files changed, 2846 insertions(+), 107 deletions(-) create mode 100644 mindspore/ccsrc/ps/core/protos/fl.proto create mode 100644 mindspore/ccsrc/ps/server/collective_ops_impl.cc create mode 100644 mindspore/ccsrc/ps/server/collective_ops_impl.h create mode 100644 mindspore/ccsrc/ps/server/distributed_count_service.cc create mode 100644 mindspore/ccsrc/ps/server/distributed_count_service.h create mode 100644 mindspore/ccsrc/ps/server/distributed_metadata_store.cc create mode 100644 mindspore/ccsrc/ps/server/distributed_metadata_store.h create mode 100644 mindspore/ccsrc/ps/server/iteration.cc create mode 100644 mindspore/ccsrc/ps/server/iteration.h create mode 100644 mindspore/ccsrc/ps/server/kernel/round/round_kernel.cc create mode 100644 mindspore/ccsrc/ps/server/kernel/round/round_kernel.h create mode 100644 mindspore/ccsrc/ps/server/kernel/round/round_kernel_factory.cc create mode 100644 mindspore/ccsrc/ps/server/kernel/round/round_kernel_factory.h create mode 100644 mindspore/ccsrc/ps/server/kernel/round/start_fl_job_kernel.cc create mode 100644 mindspore/ccsrc/ps/server/kernel/round/start_fl_job_kernel.h rename mindspore/ccsrc/ps/server/{local_meta_storage.cc => local_meta_store.cc} (79%) rename mindspore/ccsrc/ps/server/{local_meta_storage.h => local_meta_store.h} (77%) create mode 100644 mindspore/ccsrc/ps/server/model_store.cc create mode 100644 mindspore/ccsrc/ps/server/model_store.h create mode 100644 mindspore/ccsrc/ps/server/round.cc create mode 100644 mindspore/ccsrc/ps/server/round.h create mode 100644 mindspore/schema/cipher.fbs create mode 100644 mindspore/schema/fl_job.fbs diff --git a/cmake/external_libs/flatbuffers.cmake b/cmake/external_libs/flatbuffers.cmake index 2f915b17aba..78288263c1d 100644 --- a/cmake/external_libs/flatbuffers.cmake +++ b/cmake/external_libs/flatbuffers.cmake @@ -35,7 +35,7 @@ function(ms_build_flatbuffers source_schema_files set(total_schema_dirs -I ${schema_dir} ${total_schema_dirs}) endforeach() - foreach(schema ${source_schema_files}) + foreach(schema IN LISTS ${source_schema_files}) get_filename_component(filename ${schema} NAME_WE) if(NOT ${generated_output_dir} STREQUAL "") set(generated_file ${generated_output_dir}/${filename}_generated.h) diff --git a/cmake/package.cmake b/cmake/package.cmake index 6350b7f0e61..32451a63a92 100644 --- a/cmake/package.cmake +++ b/cmake/package.cmake @@ -212,7 +212,7 @@ if(ENABLE_GPU) ) endif() -if(ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)) +if(ENABLE_CPU AND NOT WIN32) install( TARGETS ps_cache DESTINATION ${INSTALL_LIB_DIR} diff --git a/mindspore/ccsrc/CMakeLists.txt b/mindspore/ccsrc/CMakeLists.txt index 370a87acfec..0a8c38886b7 100644 --- a/mindspore/ccsrc/CMakeLists.txt +++ b/mindspore/ccsrc/CMakeLists.txt @@ -373,7 +373,7 @@ elseif(CMAKE_SYSTEM_NAME MATCHES "Darwin") target_link_libraries(mindspore mindspore_gvar) target_link_libraries(_c_expression PRIVATE -Wl,-force_load mindspore mindspore_core -Wl,-noall_load) else() - if(ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)) + if(ENABLE_CPU AND NOT WIN32) target_link_libraries(mindspore proto_input mindspore::protobuf mindspore::event mindspore::event_pthreads mindspore::event_openssl mindspore::json) target_link_libraries(mindspore -Wl,--no-as-needed mindspore::event_core ps_cache) diff --git a/mindspore/ccsrc/backend/kernel_compiler/CMakeLists.txt b/mindspore/ccsrc/backend/kernel_compiler/CMakeLists.txt index 593faa29d57..8fb245d80df 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/CMakeLists.txt +++ b/mindspore/ccsrc/backend/kernel_compiler/CMakeLists.txt @@ -75,7 +75,7 @@ if(ENABLE_CPU) endif() endif() -if(NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))) +if(NOT ENABLE_CPU OR WIN32) list(REMOVE_ITEM CPU_SRC_LIST "cpu/ps/apply_momentum_ps_kernel.cc") list(REMOVE_ITEM CPU_SRC_LIST "cpu/ps/embedding_look_up_proxy_kernel.cc") list(REMOVE_ITEM CPU_SRC_LIST "cpu/ps/embedding_look_up_ps_kernel.cc") diff --git a/mindspore/ccsrc/backend/session/ascend_session.cc b/mindspore/ccsrc/backend/session/ascend_session.cc index 7c3f8b9b273..58bce2be9a4 100644 --- a/mindspore/ccsrc/backend/session/ascend_session.cc +++ b/mindspore/ccsrc/backend/session/ascend_session.cc @@ -421,7 +421,7 @@ void AscendSession::LoadInputData(const std::shared_ptr &kernel_gra size = abstract::ShapeSize(shape_tmp) * abstract::TypeIdSize(tensor->data_type()); } if (input_node->isa() && AnfAlgo::OutputAddrExist(input_node, 0) && TensorNeedSync(input_node, tensor)) { -#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) +#if (ENABLE_CPU && !_WIN32) const std::string ¶m_name = input_node->fullname_with_scope(); if (ps::ps_cache_instance.IsHashTable(param_name)) { continue; diff --git a/mindspore/ccsrc/backend/session/cpu_session.cc b/mindspore/ccsrc/backend/session/cpu_session.cc index bd7f9466971..7051b8362fe 100644 --- a/mindspore/ccsrc/backend/session/cpu_session.cc +++ b/mindspore/ccsrc/backend/session/cpu_session.cc @@ -33,7 +33,7 @@ #include "debug/anf_ir_dump.h" #include "debug/dump_proto.h" #include "debug/data_dump/dump_json_parser.h" -#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) +#if (ENABLE_CPU && !_WIN32) #include "ps/util.h" #include "ps/ps_context.h" #endif @@ -74,7 +74,7 @@ void CPUSession::Reorder(std::vector *node_list) { AnfAlgo::ReorderPos void CPUSession::Optimize(const std::shared_ptr &kernel_graph) { auto optimizer = std::make_shared(); auto pm = std::make_shared(); -#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) +#if (ENABLE_CPU && !_WIN32) auto ms_context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(ms_context); if (ms_context->get_param(MS_CTX_EXECUTION_MODE) != kPynativeMode && ps::PSContext::instance()->is_ps_mode()) { @@ -174,7 +174,7 @@ void CPUSession::PreExecuteGraph(const std::shared_ptr &kernel_grap MS_LOG(INFO) << "Bind input output address"; runtime_.BindInputOutput(kernel_graph.get(), inputs, outputs); -#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) +#if (ENABLE_CPU && !_WIN32) InitPSParamAndOptim(kernel_graph, inputs); #endif } diff --git a/mindspore/ccsrc/backend/session/executor.cc b/mindspore/ccsrc/backend/session/executor.cc index 202410b1436..9b3997d4269 100644 --- a/mindspore/ccsrc/backend/session/executor.cc +++ b/mindspore/ccsrc/backend/session/executor.cc @@ -21,7 +21,7 @@ #include "utils/comm_manager.h" #include "utils/scoped_long_running.h" #include "pybind_api/ir/tensor_py.h" -#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) +#if (ENABLE_CPU && !_WIN32) #include "ps/ps_cache/ps_cache_manager.h" #endif diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc index 9dd63fa7e78..026eb785cf9 100644 --- a/mindspore/ccsrc/backend/session/session_basic.cc +++ b/mindspore/ccsrc/backend/session/session_basic.cc @@ -43,7 +43,7 @@ #include "debug/common.h" #include "utils/trace_base.h" #include "frontend/parallel/context.h" -#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) +#if (ENABLE_CPU && !_WIN32) #include "ps/ps_cache/ps_cache_manager.h" #include "ps/constants.h" #include "ps/util.h" @@ -2357,7 +2357,7 @@ void SessionBasic::DumpGraph(const std::shared_ptr &kernel_graph) { #endif } -#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) +#if (ENABLE_CPU && !_WIN32) void SessionBasic::InitPsWorker(const KernelGraphPtr &kernel_graph) { if (!ps::PSContext::instance()->is_worker()) { return; diff --git a/mindspore/ccsrc/backend/session/session_basic.h b/mindspore/ccsrc/backend/session/session_basic.h index dfbfe89f884..4ba99bc1b5a 100644 --- a/mindspore/ccsrc/backend/session/session_basic.h +++ b/mindspore/ccsrc/backend/session/session_basic.h @@ -244,7 +244,7 @@ class SessionBasic : public std::enable_shared_from_this { std::vector GetAllReduceSplitIndex(); virtual std::string GetCommWorldGroup() { return std::string(); } void DumpGraph(const std::shared_ptr &kernel_graph); -#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) +#if (ENABLE_CPU && !_WIN32) void CheckPSModeConsistence(const KernelGraphPtr &kernel_graph) const; void GetBatchElements(const AnfNodePtr &kernel_node) const; void InitPsWorker(const KernelGraphPtr &kernel_graph); @@ -263,7 +263,7 @@ class SessionBasic : public std::enable_shared_from_this { #if !defined(_WIN32) && !defined(_WIN64) std::shared_ptr debugger_; #endif -#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) +#if (ENABLE_CPU && !_WIN32) bool initialized_ps_cache_{false}; #endif }; diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc index c8abaf00ebf..766fff47ff8 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc @@ -25,7 +25,7 @@ #include "frontend/parallel/device_matrix.h" #include "frontend/parallel/graph_util/generate_graph.h" #include "frontend/parallel/context.h" -#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) +#if (ENABLE_CPU && !_WIN32) #include "ps/ps_cache/ps_cache_manager.h" #include "utils/ms_context.h" #endif @@ -160,7 +160,7 @@ Status GatherPInfo::GetAttrs() { if (std::find(inputs_shape_[1].begin(), inputs_shape_[1].end(), -1) != inputs_shape_[1].end()) { dynamic_shape_indices_ = true; } -#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) +#if (ENABLE_CPU && !_WIN32) MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode(); MS_EXCEPTION_IF_NULL(MsContext::GetInstance()); @@ -617,7 +617,7 @@ Status GatherPInfo::InferBias() { rank = rank % (params_strategy[0] * params_strategy[1]); } } -#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) +#if (ENABLE_CPU && !_WIN32) if (ps::PsDataPrefetch::GetInstance().cache_enable()) { bias_ = static_cast(ps::PsCacheManager::GetInstance().cache_indices_lower_bound()); return SUCCESS; diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/unique_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/unique_info.cc index e3c455f8822..bcba19b0ec2 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/unique_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/unique_info.cc @@ -28,7 +28,7 @@ #include "frontend/parallel/strategy.h" #include "frontend/parallel/context.h" #include "frontend/parallel/tensor_layout/tensor_redistribution.h" -#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) +#if (ENABLE_CPU && !_WIN32) #include "ps/ps_cache/ps_cache_manager.h" #endif @@ -192,7 +192,7 @@ Status UniqueInfo::GenerateStrategies(int64_t stage_id) { return SUCCESS; } -#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) +#if (ENABLE_CPU && !_WIN32) Status UniqueInfo::ComputeReplaceGraph(const CNodePtr &cnode) { GenerateGraph gen_g = GenerateGraph(); if (gen_g.Init(cnode) != SUCCESS) { @@ -230,7 +230,7 @@ Status UniqueInfo::ComputeReplaceGraph(const CNodePtr &cnode) { #endif ReplaceGraphPtr UniqueInfo::replace_graph(const CNodePtr &cnode) { -#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) +#if (ENABLE_CPU && !_WIN32) if (ps::PsDataPrefetch::GetInstance().cache_enable()) { auto inputs = cnode->inputs(); if (inputs.empty()) { diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/unique_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/unique_info.h index b2037617f7b..159bf73953c 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/unique_info.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/unique_info.h @@ -51,7 +51,7 @@ class UniqueInfo : public OperatorInfo { Status InferMirrorOps() override; Status InferForwardCommunication() override { return SUCCESS; } Status InferAsLossDivisor() override { return SUCCESS; } -#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) +#if (ENABLE_CPU && !_WIN32) Status ComputeReplaceGraph(const CNodePtr &cnode); #endif diff --git a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc index e94df63850e..7343ac65dca 100644 --- a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc +++ b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc @@ -47,14 +47,14 @@ #include "ir/anf.h" #include "ir/param_info.h" #include "ir/tensor.h" -#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) +#if (ENABLE_CPU && !_WIN32) #include "ps/util.h" #endif namespace mindspore { namespace parallel { bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) { -#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) +#if (ENABLE_CPU && !_WIN32) if (ps::Util::IsRoleOfPServer() || ps::Util::IsRoleOfScheduler()) { return false; } diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc index 500162a7202..3ad85726c7f 100644 --- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc +++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc @@ -46,7 +46,7 @@ #include "utils/ms_context.h" #include "utils/symbolic.h" #include "mindspore/core/utils/parallel_node_check.h" -#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) +#if (ENABLE_CPU && !_WIN32) #include "ps/util.h" #include "ps/ps_context.h" #endif @@ -3553,7 +3553,7 @@ static void HandleFullySplitParameters(const FuncGraphPtr &root) { } bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) { -#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) +#if (ENABLE_CPU && !_WIN32) if (ps::PSContext::instance()->is_server() || ps::PSContext::instance()->is_scheduler()) { return false; } diff --git a/mindspore/ccsrc/minddata/dataset/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/CMakeLists.txt index 125bee998fb..5715a9b86e9 100644 --- a/mindspore/ccsrc/minddata/dataset/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/CMakeLists.txt @@ -295,7 +295,7 @@ if(${CMAKE_SYSTEM_NAME} MATCHES "Windows") target_link_libraries(_c_dataengine PRIVATE _c_mindrecord ${MINDRECORD_LINK_OBJECT} mindspore::sqlite) else() target_link_libraries(_c_dataengine PRIVATE _c_mindrecord) - if(ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)) + if(ENABLE_CPU AND NOT WIN32) if(${ENABLE_IBVERBS} STREQUAL "ON") target_link_libraries(_c_dataengine PRIVATE ibverbs rdmacm) endif() diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/engine/cache/CMakeLists.txt index 9f0123c26db..d566ac44ee2 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/CMakeLists.txt @@ -1,7 +1,8 @@ add_subdirectory(perf EXCLUDE_FROM_ALL) include_directories("${CMAKE_BINARY_DIR}/minddata/dataset/engine/cache") set(MD_FLATBUFFER_OU "${CMAKE_BINARY_DIR}/minddata/dataset/engine/cache") -ms_build_flatbuffers("de_tensor.fbs" ${CMAKE_CURRENT_SOURCE_DIR} generated_engine_files ${MD_FLATBUFFER_OU}) +set(FBS_FILES de_tensor.fbs) +ms_build_flatbuffers(FBS_FILES ${CMAKE_CURRENT_SOURCE_DIR} generated_engine_files ${MD_FLATBUFFER_OU}) file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) diff --git a/mindspore/ccsrc/pipeline/jit/action.cc b/mindspore/ccsrc/pipeline/jit/action.cc index 3e4b6035727..dd56ac35331 100644 --- a/mindspore/ccsrc/pipeline/jit/action.cc +++ b/mindspore/ccsrc/pipeline/jit/action.cc @@ -43,7 +43,7 @@ #include "vm/transform.h" #include "parse/python_adapter.h" #include "frontend/optimizer/py_pass_manager.h" -#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) +#if (ENABLE_CPU && !_WIN32) #include "ps/parameter_server.h" #include "ps/scheduler.h" #include "ps/worker.h" @@ -606,7 +606,7 @@ bool ExecuteAction(const ResourcePtr &res) { return true; } -#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) +#if (ENABLE_CPU && !_WIN32) bool StartPSWorkerAction(const ResourcePtr &res) { ps::Worker::GetInstance().Run(); return true; @@ -782,7 +782,7 @@ std::vector VmPipeline() { actions.emplace_back(std::make_pair("auto_monad_reorder", OrderEnforceAction)); actions.emplace_back(std::make_pair("validate", ValidateAction)); -#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) +#if (ENABLE_CPU && !_WIN32) if (ps::PSContext::instance()->is_worker()) { actions.emplace_back(std::make_pair("worker", StartPSWorkerAction)); } @@ -796,7 +796,7 @@ std::vector VmPipeline() { return actions; } -#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) +#if (ENABLE_CPU && !_WIN32) std::vector PServerPipeline() { auto actions = CommonPipeline(); actions.emplace_back(std::make_pair("optimize", VmOptimizeAction)); diff --git a/mindspore/ccsrc/pipeline/jit/init.cc b/mindspore/ccsrc/pipeline/jit/init.cc index 5f4f0523299..a6f62c74811 100644 --- a/mindspore/ccsrc/pipeline/jit/init.cc +++ b/mindspore/ccsrc/pipeline/jit/init.cc @@ -34,7 +34,7 @@ #else #include "runtime/device/gpu/distribution/collective_fake_init.h" #endif -#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) +#if (ENABLE_CPU && !_WIN32) #include "ps/util.h" #endif #include "ps/ps_context.h" diff --git a/mindspore/ccsrc/pipeline/jit/pass.cc b/mindspore/ccsrc/pipeline/jit/pass.cc index 60a97e3ce47..6f65e95a94f 100644 --- a/mindspore/ccsrc/pipeline/jit/pass.cc +++ b/mindspore/ccsrc/pipeline/jit/pass.cc @@ -42,7 +42,7 @@ #include "pipeline/jit/pipeline_split.h" #include "pipeline/jit/static_analysis/auto_monad.h" #include "frontend/optimizer/irpass/gradient_eliminate.h" -#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) +#if (ENABLE_CPU && !_WIN32) #include "ps/util.h" #include "ps/ps_context.h" #endif @@ -407,7 +407,7 @@ bool AddRecomputationPass(const ResourcePtr &res) { } bool AddCacheEmbeddingPass(const ResourcePtr &res) { -#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) +#if (ENABLE_CPU && !_WIN32) if (ps::PSContext::instance()->is_ps_mode()) { return true; } diff --git a/mindspore/ccsrc/pipeline/jit/pipeline.cc b/mindspore/ccsrc/pipeline/jit/pipeline.cc index 15f8132cc2c..3636e92a543 100644 --- a/mindspore/ccsrc/pipeline/jit/pipeline.cc +++ b/mindspore/ccsrc/pipeline/jit/pipeline.cc @@ -49,7 +49,7 @@ #include "utils/shape_utils.h" #include "utils/info.h" #include "load_mindir/load_model.h" -#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) +#if (ENABLE_CPU && !_WIN32) #include "ps/constants.h" #include "ps/util.h" #include "ps/worker.h" @@ -528,7 +528,7 @@ std::vector GetPipeline(const ResourcePtr &resource, const std::stri std::string backend = MsContext::GetInstance()->backend_policy(); -#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) +#if (ENABLE_CPU && !_WIN32) if (ps::PSContext::instance()->is_server()) { resource->results()[kBackend] = compile::CreateBackend(); return PServerPipeline(); @@ -961,7 +961,7 @@ bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t ba bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batch_size, const std::vector &types, const std::vector> &shapes, const std::vector &input_indexes, bool need_run) { -#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) +#if (ENABLE_CPU && !_WIN32) if ((ps::PSContext::instance()->is_ps_mode()) && (!ps::PSContext::instance()->is_worker())) { return true; } @@ -1027,7 +1027,7 @@ bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batc } ConfigManager::GetInstance().set_iter_num(size); // PS cache does not support loop sink. -#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) +#if (ENABLE_CPU && !_WIN32) if (ps::PSContext::instance()->is_worker() && ps::PsDataPrefetch::GetInstance().cache_enable()) { ps::PsDataPrefetch::GetInstance().CreateDataChannel(queue_name, LongToSize(size)); ConfigManager::GetInstance().set_iter_num(1); @@ -1150,7 +1150,7 @@ void FinalizeBackend() { void ClearResAtexit() { MS_LOG(DEBUG) << "Pipeline clear all resource"; pynative::ClearPyNativeSession(); -#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) +#if (ENABLE_CPU && !_WIN32) if (ps::PSContext::instance()->is_ps_mode() && ps::PSContext::instance()->is_worker()) { if (ps::PsDataPrefetch::GetInstance().cache_enable()) { ps::ps_cache_instance.Finalize(); diff --git a/mindspore/ccsrc/ps/CMakeLists.txt b/mindspore/ccsrc/ps/CMakeLists.txt index a7752af1585..2245b793246 100644 --- a/mindspore/ccsrc/ps/CMakeLists.txt +++ b/mindspore/ccsrc/ps/CMakeLists.txt @@ -1,6 +1,13 @@ file(GLOB_RECURSE _PS_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") -if(NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))) +set(SERVER_FLATBUFFER_OUTPUT "${CMAKE_BINARY_DIR}/schema") +set(FBS_FILES + ${CMAKE_CURRENT_SOURCE_DIR}/../../schema/cipher.fbs + ${CMAKE_CURRENT_SOURCE_DIR}/../../schema/fl_job.fbs + ) +ms_build_flatbuffers(FBS_FILES ${CMAKE_CURRENT_SOURCE_DIR}../../schema generated_fbs_files ${SERVER_FLATBUFFER_OUTPUT}) + +if(NOT ENABLE_CPU OR WIN32) list(REMOVE_ITEM _PS_SRC_FILES "optimizer_info_builder.cc") list(REMOVE_ITEM _PS_SRC_FILES "optimizer_info.cc") list(REMOVE_ITEM _PS_SRC_FILES "scheduler.cc") @@ -12,11 +19,6 @@ if(NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))) list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/tcp_client.cc") list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/tcp_message_handler.cc") list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/tcp_server.cc") - list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/communicator_base.cc") - list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/http_communicator.cc") - list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/tcp_communicator.cc") - list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/http_msg_handler.cc") - list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/tcp_msg_handler.cc") list(REMOVE_ITEM _PS_SRC_FILES "core/node.cc") list(REMOVE_ITEM _PS_SRC_FILES "core/node_manager.cc") list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_cache_manager.cc") @@ -39,18 +41,32 @@ if(NOT ENABLE_GPU) list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/gpu/gpu_ps_cache.cc") endif() -if(WIN32 OR NOT ENABLE_CPU) +if(NOT ENABLE_CPU OR WIN32) + list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/communicator_base.cc") + list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/http_communicator.cc") + list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/tcp_communicator.cc") + list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/http_msg_handler.cc") + list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/tcp_msg_handler.cc") list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/apply_momentum_kernel.cc") list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/aggregation_kernel_factory.cc") list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/dense_grad_accum_kernel.cc") list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/optimizer_kernel_factory.cc") + list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/round/round_kernel_factory.cc") + list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/round/round_kernel.cc") + list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/round/start_fl_job_kernel.cc") list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/params_info.cc") list(REMOVE_ITEM _PS_SRC_FILES "server/consistent_hash_ring.cc") list(REMOVE_ITEM _PS_SRC_FILES "server/iteration_timer.cc") - list(REMOVE_ITEM _PS_SRC_FILES "server/local_meta_storage.cc") + list(REMOVE_ITEM _PS_SRC_FILES "server/local_meta_store.cc") list(REMOVE_ITEM _PS_SRC_FILES "server/memory_register.cc") list(REMOVE_ITEM _PS_SRC_FILES "server/parameter_aggregator.cc") list(REMOVE_ITEM _PS_SRC_FILES "server/executor.cc") + list(REMOVE_ITEM _PS_SRC_FILES "server/collective_ops_impl.cc") + list(REMOVE_ITEM _PS_SRC_FILES "server/distributed_count_service.cc") + list(REMOVE_ITEM _PS_SRC_FILES "server/distributed_metadata_store.cc") + list(REMOVE_ITEM _PS_SRC_FILES "server/iteration.cc") + list(REMOVE_ITEM _PS_SRC_FILES "server/model_store.cc") + list(REMOVE_ITEM _PS_SRC_FILES "server/round.cc") endif() list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_data/ps_data_prefetch.cc") @@ -59,3 +75,5 @@ add_subdirectory(ps_cache) set_property(SOURCE ${_PS_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PS) add_library(_mindspore_ps_obj OBJECT ${_PS_SRC_FILES}) +add_dependencies(_mindspore_ps_obj generated_fbs_files) +target_link_libraries(_mindspore_ps_obj mindspore::flatbuffers) diff --git a/mindspore/ccsrc/ps/core/communicator/tcp_communicator.h b/mindspore/ccsrc/ps/core/communicator/tcp_communicator.h index 7d014083539..286fcbf59b6 100644 --- a/mindspore/ccsrc/ps/core/communicator/tcp_communicator.h +++ b/mindspore/ccsrc/ps/core/communicator/tcp_communicator.h @@ -34,13 +34,26 @@ namespace mindspore { namespace ps { namespace core { -enum class TcpUserCommand { kPush, kPull, kCount, kReachThreshold, kResetCount, kGetValue, kPutValue, kCounterEvent }; +enum class TcpUserCommand { + kPush, + kPull, + kCount, + kReachThreshold, + kResetCount, + kGetMetadata, + kUpdateMetadata, + kCounterEvent +}; const std::unordered_map kUserCommandToMsgType = { - {TcpUserCommand::kPush, "push"}, {TcpUserCommand::kPull, "pull"}, - {TcpUserCommand::kCount, "count"}, {TcpUserCommand::kReachThreshold, "reachThreshold"}, - {TcpUserCommand::kResetCount, "resetCnt"}, {TcpUserCommand::kGetValue, "getValue"}, - {TcpUserCommand::kPutValue, "putValue"}, {TcpUserCommand::kCounterEvent, "counterEvent"}, + {TcpUserCommand::kPush, "push"}, + {TcpUserCommand::kPull, "pull"}, + {TcpUserCommand::kCount, "count"}, + {TcpUserCommand::kReachThreshold, "countReachThreshold"}, + {TcpUserCommand::kResetCount, "resetCnt"}, + {TcpUserCommand::kGetMetadata, "getMetadata"}, + {TcpUserCommand::kUpdateMetadata, "updateMetadata"}, + {TcpUserCommand::kCounterEvent, "counterEvent"}, }; class TcpCommunicator : public CommunicatorBase { diff --git a/mindspore/ccsrc/ps/core/protos/fl.proto b/mindspore/ccsrc/ps/core/protos/fl.proto new file mode 100644 index 00000000000..daa63aa8507 --- /dev/null +++ b/mindspore/ccsrc/ps/core/protos/fl.proto @@ -0,0 +1,155 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +syntax = "proto3"; +package mindspore.ps; + +message CollectiveData { + bytes data = 1; +} + +message CountRequest { + string name = 1; + string id = 2; +} + +message CountResponse { + bool result = 1; + string reason = 2; +} + +message CountReachThresholdRequest { + string name = 1; +} + +message CountReachThresholdResponse { + bool is_enough = 1; +} + +message ResetCounterRequest { + string name = 1; +} + +message UpdateMetadataRequest { + string name = 1; + bytes value = 2; +} + +message GetMetadataRequest { + string name = 1; +} + +message GetMetadataResponse { + bytes value = 1; +} + +enum CounterEventType { + FIRST_CNT = 0; + LAST_CNT = 1; +} + +message CounterEvent { + CounterEventType type = 1; + string name = 2; + bytes data = 3; +} + +message FLId { + string fl_id = 1; +} + +message UpdateModelClientList { + repeated string fl_id = 1; +} + +message DeviceMeta { + string fl_name = 1; + string fl_id = 2; + uint64 data_size = 3; +} + +message FLIdToDeviceMeta { + map fl_id_to_meta = 1; +} + +message UpdateModelThreshold { + uint64 threshold = 1; +} + +message ClientShares { + map client_secret_shares = 1; +} + +message PairClientShares { + string fl_id = 1; + SharesPb client_shares = 2; +} + +message ClientKeys { + map client_keys = 1; +} + +message ClientNoises { + OneClientNoises one_client_noises = 1; +} + +message PairClientKeys { + string fl_id = 1; + KeysPb client_keys = 2; +} + +message OneClientNoises { + repeated float noise = 1; +} + +message ClientShareStr { + string fl_id = 1; + bytes share = 2; // todo: verify the correctness + int32 index = 3; +} + +message SharesPb { + repeated ClientShareStr clientsharestrs = 1; +} + +message KeysPb { + repeated bytes key = 1; +} + +message PBMetadata { + oneof value { + DeviceMeta device_meta = 1; + FLIdToDeviceMeta device_metas = 2; + + FLId fl_id = 3; + UpdateModelClientList client_list = 4; + + UpdateModelThreshold update_model_threshold = 5; + + PairClientShares pair_client_shares = 6; + ClientShares client_shares = 7; + + PairClientKeys pair_client_keys = 8; + ClientKeys client_keys = 9; + + OneClientNoises one_client_noises = 10; + ClientNoises client_noises = 11; + } +} + +message PBMetadataWithName { + string name = 1; + PBMetadata metadata = 2; +} diff --git a/mindspore/ccsrc/ps/core/protos/ps.proto b/mindspore/ccsrc/ps/core/protos/ps.proto index bc5c18246a5..090ef6bb5e8 100644 --- a/mindspore/ccsrc/ps/core/protos/ps.proto +++ b/mindspore/ccsrc/ps/core/protos/ps.proto @@ -60,4 +60,4 @@ message EmbeddingTableLookup { uint64 key = 2; repeated int32 keys = 3; repeated float values = 4; -} \ No newline at end of file +} diff --git a/mindspore/ccsrc/ps/ps_cache/CMakeLists.txt b/mindspore/ccsrc/ps/ps_cache/CMakeLists.txt index 6082f4fd867..524f448c165 100644 --- a/mindspore/ccsrc/ps/ps_cache/CMakeLists.txt +++ b/mindspore/ccsrc/ps/ps_cache/CMakeLists.txt @@ -1,4 +1,4 @@ -if(ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)) +if(ENABLE_CPU AND NOT WIN32) file(GLOB_RECURSE _PS_CACHE_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "ps_data/*.cc") set_property(SOURCE ${_PS_CACHE_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PS) add_library(ps_cache SHARED ${_PS_CACHE_SRC_FILES}) diff --git a/mindspore/ccsrc/ps/ps_context.cc b/mindspore/ccsrc/ps/ps_context.cc index 039f8f792a3..222a3f31933 100644 --- a/mindspore/ccsrc/ps/ps_context.cc +++ b/mindspore/ccsrc/ps/ps_context.cc @@ -18,7 +18,7 @@ #include "utils/log_adapter.h" #include "utils/ms_utils.h" #include "backend/kernel_compiler/kernel.h" -#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) +#if (ENABLE_CPU && !_WIN32) #include "ps/ps_cache/ps_cache_manager.h" #include "ps/ps_cache/ps_data/ps_data_prefetch.h" #endif @@ -68,7 +68,7 @@ void PSContext::Reset() { is_worker_ = false; is_pserver_ = false; is_sched_ = false; -#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) +#if (ENABLE_CPU && !_WIN32) if (ps::PsDataPrefetch::GetInstance().cache_enable()) { ps_cache_instance.Finalize(); set_cache_enable(false); @@ -108,46 +108,62 @@ int PSContext::ps_rank_id() const { return rank_id_; } void PSContext::InsertHashTableSize(const std::string ¶m_name, size_t cache_vocab_size, size_t embedding_size, size_t vocab_size) const { -#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) +#if (ENABLE_CPU && !_WIN32) ps_cache_instance.InsertHashTableSize(param_name, cache_vocab_size, embedding_size, vocab_size); #endif } void PSContext::ReInsertHashTableSize(const std::string &new_param_name, const std::string &cur_param_name, size_t cache_vocab_size, size_t embedding_size) const { -#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) +#if (ENABLE_CPU && !_WIN32) ps_cache_instance.ReInsertHashTableSize(new_param_name, cur_param_name, cache_vocab_size, embedding_size); #endif } void PSContext::InsertWeightInitInfo(const std::string ¶m_name, size_t global_seed, size_t op_seed) const { -#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) +#if (ENABLE_CPU && !_WIN32) ps_cache_instance.InsertWeightInitInfo(param_name, global_seed, op_seed); #endif } void PSContext::InsertAccumuInitInfo(const std::string ¶m_name, float init_val) const { -#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) +#if (ENABLE_CPU && !_WIN32) ps_cache_instance.InsertAccumuInitInfo(param_name, init_val); #endif } void PSContext::CloneHashTable(const std::string &dest_param_name, const std::string &src_param_name) const { -#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) +#if (ENABLE_CPU && !_WIN32) ps_cache_instance.CloneHashTable(dest_param_name, src_param_name); #endif } void PSContext::set_cache_enable(bool cache_enable) const { -#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) +#if (ENABLE_CPU && !_WIN32) PsDataPrefetch::GetInstance().set_cache_enable(cache_enable); #endif } void PSContext::set_rank_id(int rank_id) const { -#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) +#if (ENABLE_CPU && !_WIN32) ps_cache_instance.set_rank_id(rank_id); #endif } + +void PSContext::set_fl_name(const std::string &fl_name) { fl_name_ = fl_name; } + +const std::string &PSContext::fl_name() const { return fl_name_; } + +void PSContext::set_fl_iteration_num(uint64_t fl_iteration_num) { fl_iteration_num_ = fl_iteration_num; } + +uint64_t PSContext::fl_iteration_num() const { return fl_iteration_num_; } + +void PSContext::set_client_epoch_num(uint64_t client_epoch_num) { client_epoch_num_ = client_epoch_num; } + +uint64_t PSContext::client_epoch_num() const { return client_epoch_num_; } + +void PSContext::set_client_batch_size(uint64_t client_batch_size) { client_batch_size_ = client_batch_size; } + +uint64_t PSContext::client_batch_size() const { return client_batch_size_; } } // namespace ps } // namespace mindspore diff --git a/mindspore/ccsrc/ps/ps_context.h b/mindspore/ccsrc/ps/ps_context.h index 1f474038a4e..6a506959359 100644 --- a/mindspore/ccsrc/ps/ps_context.h +++ b/mindspore/ccsrc/ps/ps_context.h @@ -60,6 +60,19 @@ class PSContext { void set_cache_enable(bool cache_enable) const; void set_rank_id(int rank_id) const; + // Setter and getter for federated learning. + void set_fl_name(const std::string &fl_name); + const std::string &fl_name() const; + + void set_fl_iteration_num(uint64_t fl_iteration_num); + uint64_t fl_iteration_num() const; + + void set_client_epoch_num(uint64_t client_epoch_num); + uint64_t client_epoch_num() const; + + void set_client_batch_size(uint64_t client_batch_size); + uint64_t client_batch_size() const; + private: PSContext() : ps_enabled_(false), @@ -80,6 +93,12 @@ class PSContext { uint32_t server_num_; std::string scheduler_host_; uint16_t scheduler_port_; + + // Members for federated learning. + std::string fl_name_; + uint64_t fl_iteration_num_; + uint64_t client_epoch_num_; + uint64_t client_batch_size_; }; } // namespace ps } // namespace mindspore diff --git a/mindspore/ccsrc/ps/server/collective_ops_impl.cc b/mindspore/ccsrc/ps/server/collective_ops_impl.cc new file mode 100644 index 00000000000..22829d16839 --- /dev/null +++ b/mindspore/ccsrc/ps/server/collective_ops_impl.cc @@ -0,0 +1,223 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ps/server/collective_ops_impl.h" + +namespace mindspore { +namespace ps { +namespace server { +void CollectiveOpsImpl::Initialize(const std::shared_ptr &server_node) { + MS_EXCEPTION_IF_NULL(server_node); + server_node_ = server_node; + local_rank_ = server_node_->rank_id(); + server_num_ = PSContext::instance()->initial_server_num(); + return; +} + +template +bool CollectiveOpsImpl::RingAllReduce(const void *sendbuff, void *recvbuff, size_t count) { + int ret = memcpy_s(recvbuff, count * sizeof(T), sendbuff, count * sizeof(T)); + if (ret != 0) { + MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; + return false; + } + + uint32_t rank_size = server_num_; + uint32_t local_rank_ = server_node_->rank_id(); + size_t chunk_size = count / rank_size; + size_t remainder_size = count % rank_size; + std::vector chunk_sizes(rank_size, chunk_size); + // The rest of the data should be assigned to each chunk. + for (size_t i = 0; i < remainder_size; i++) { + chunk_sizes[i]++; + } + // Store offsets to get every data chunk's address. + std::vector chunk_offset; + for (size_t i = 0; i < rank_size; i++) { + size_t ofs = + std::accumulate(chunk_sizes.begin(), chunk_sizes.begin() + i, static_cast(0), std::plus()); + chunk_offset.push_back(ofs); + } + + T *output_buff = reinterpret_cast(recvbuff); + uint32_t send_to_rank = (local_rank_ + 1) % rank_size; + uint32_t recv_from_rank = (local_rank_ - 1 + rank_size) % rank_size; + MS_LOG(DEBUG) << "AllReduce count:" << count << ", rank_size:" << rank_size << ", local_rank_:" << local_rank_ + << ", chunk_size:" << chunk_size << ", remainder_size:" << remainder_size + << ", chunk_sizes:" << chunk_sizes << ", send_to_rank:" << send_to_rank + << ", recv_from_rank:" << recv_from_rank; + + // Ring ReduceScatter. + MS_LOG(DEBUG) << "Start Ring ReduceScatter."; + std::unique_ptr tmp_recv_chunk = std::make_unique(chunk_sizes[0]); + for (size_t i = 0; i < rank_size - 1; i++) { + // Step 1: Async send data to next rank. + size_t send_chunk_index = (local_rank_ - i + rank_size) % rank_size; + T *send_chunk = output_buff + chunk_offset[send_chunk_index]; + auto send_req_id = server_node_->CollectiveSendAsync(core::NodeRole::SERVER, send_to_rank, send_chunk, + chunk_sizes[send_chunk_index] * sizeof(T)); + // Step 2: Async receive data to next rank and wait until it's done. + size_t recv_chunk_index = (local_rank_ - i - 1 + rank_size) % rank_size; + T *recv_chunk = output_buff + chunk_offset[recv_chunk_index]; + MS_LOG(DEBUG) << "Ring ReduceScatter send_to_rank:" << send_to_rank << ", recv_from_rank:" << recv_from_rank + << ", send count:" << chunk_sizes[send_chunk_index] + << ", recv count:" << chunk_sizes[recv_chunk_index] << ", iteration:" << i; + + std::shared_ptr> recv_str; + auto recv_req_id = server_node_->CollectiveReceiveAsync(core::NodeRole::SERVER, recv_from_rank, &recv_str); + if (!server_node_->CollectiveWait(recv_req_id, 1)) { + MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed."; + return false; + } + memcpy_s(tmp_recv_chunk.get(), chunk_sizes[recv_chunk_index] * sizeof(T), recv_str->data(), recv_str->size()); + + // Step 3: Reduce the data so we can overlap the time cost of send. + for (size_t j = 0; j < chunk_sizes[recv_chunk_index]; j++) { + recv_chunk[j] += tmp_recv_chunk[j]; + } + // Step 4: Wait until send is done. + if (!server_node_->Wait(send_req_id, 1)) { + MS_LOG(ERROR) << "CollectiveWait " << send_req_id << " failed."; + return false; + } + } + MS_LOG(DEBUG) << "End Ring ReduceScatter."; + + // Ring AllGather. + MS_LOG(DEBUG) << "Start Ring AllGather."; + for (size_t i = 0; i < rank_size - 1; i++) { + size_t send_chunk_index = (local_rank_ - i + 1 + rank_size) % rank_size; + T *send_chunk = output_buff + chunk_offset[send_chunk_index]; + auto send_req_id = server_node_->CollectiveSendAsync(core::NodeRole::SERVER, send_to_rank, send_chunk, + chunk_sizes[send_chunk_index] * sizeof(T)); + size_t recv_chunk_index = (local_rank_ - i + rank_size) % rank_size; + T *recv_chunk = output_buff + chunk_offset[recv_chunk_index]; + MS_LOG(DEBUG) << "Ring AllGather send_to_rank:" << send_to_rank << ", recv_from_rank:" << recv_from_rank + << ", send count:" << chunk_sizes[send_chunk_index] + << ", recv count:" << chunk_sizes[recv_chunk_index] << ", iteration:" << i; + + std::shared_ptr> recv_str; + auto recv_req_id = server_node_->CollectiveReceiveAsync(core::NodeRole::SERVER, recv_from_rank, &recv_str); + + if (!server_node_->CollectiveWait(recv_req_id, 1)) { + MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed."; + return false; + } + memcpy_s(recv_chunk, chunk_sizes[recv_chunk_index] * sizeof(T), recv_str->data(), recv_str->size()); + if (!server_node_->Wait(send_req_id, 1)) { + MS_LOG(ERROR) << "CollectiveWait " << send_req_id << " failed."; + return false; + } + } + MS_LOG(DEBUG) << "End Ring AllGather."; + return true; +} + +template +bool CollectiveOpsImpl::ReduceBroadcastAllReduce(const void *sendbuff, void *recvbuff, size_t count) { + uint32_t rank_size = server_num_; + uint32_t local_rank_ = server_node_->rank_id(); + MS_LOG(DEBUG) << "Reduce Broadcast AllReduce rank_size:" << rank_size << ", local_rank_:" << local_rank_ + << ", count:" << count; + int ret = memcpy_s(recvbuff, count * sizeof(T), sendbuff, count * sizeof(T)); + if (ret != 0) { + MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; + return false; + } + T *output_buff = reinterpret_cast(recvbuff); + // Reduce data to rank 0 process. + MS_LOG(DEBUG) << "Start Reduce to rank 0 process."; + if (local_rank_ == 0) { + std::unique_ptr tmp_recv_buff = std::make_unique(count); + for (uint32_t i = 1; i < rank_size; i++) { + std::shared_ptr> recv_str; + MS_LOG(DEBUG) << "Reduce rank 0 receive from rank " << i; + auto recv_req_id = server_node_->CollectiveReceiveAsync(core::NodeRole::SERVER, i, &recv_str); + if (!server_node_->CollectiveWait(recv_req_id, 1)) { + MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed."; + return false; + } + memcpy_s(tmp_recv_buff.get(), count * sizeof(T), recv_str->data(), recv_str->size()); + for (size_t j = 0; j < count; j++) { + output_buff[j] += tmp_recv_buff[j]; + } + } + } else { + MS_LOG(DEBUG) << "Reduce send data to rank 0 process."; + auto send_req_id = server_node_->CollectiveSendAsync(core::NodeRole::SERVER, 0, sendbuff, count * sizeof(T)); + if (!server_node_->Wait(send_req_id, 1)) { + MS_LOG(ERROR) << "CollectiveWait " << send_req_id << " failed."; + return false; + } + } + MS_LOG(DEBUG) << "End Reduce."; + + // Broadcast data to not 0 rank process. + MS_LOG(DEBUG) << "Start broadcast from rank 0 to other processes."; + if (local_rank_ == 0) { + for (uint32_t i = 1; i < rank_size; i++) { + MS_LOG(DEBUG) << "Broadcast data to process " << i; + auto send_req_id = server_node_->CollectiveSendAsync(core::NodeRole::SERVER, i, output_buff, count * sizeof(T)); + if (!server_node_->Wait(send_req_id, 1)) { + MS_LOG(ERROR) << "CollectiveWait " << send_req_id << " failed."; + return false; + } + } + } else { + MS_LOG(DEBUG) << "Broadcast receive from rank 0."; + std::shared_ptr> recv_str; + auto recv_req_id = server_node_->CollectiveReceiveAsync(core::NodeRole::SERVER, 0, &recv_str); + if (!server_node_->CollectiveWait(recv_req_id, 1)) { + MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed."; + return false; + } + memcpy_s(output_buff, count * sizeof(T), recv_str->data(), recv_str->size()); + } + MS_LOG(DEBUG) << "End broadcast."; + return true; +} + +template +bool CollectiveOpsImpl::AllReduce(const void *sendbuff, void *recvbuff, size_t count) { + // The collective communication API does not support calling Send and Recv concurrently with multiple threads; + std::unique_lock lock(mtx_); + if (sendbuff == nullptr || recvbuff == nullptr) { + MS_LOG(ERROR) << "AllReduce sendbuff or recvbuff is nullptr."; + return false; + } + + uint32_t rank_size = server_num_; + if (count >= rank_size) { + return RingAllReduce(sendbuff, recvbuff, count); + } else { + return ReduceBroadcastAllReduce(sendbuff, recvbuff, count); + } +} + +template bool CollectiveOpsImpl::RingAllReduce(const void *sendbuff, void *recvbuff, size_t count); +template bool CollectiveOpsImpl::RingAllReduce(const void *sendbuff, void *recvbuff, size_t count); +template bool CollectiveOpsImpl::RingAllReduce(const void *sendbuff, void *recvbuff, size_t count); + +template bool CollectiveOpsImpl::ReduceBroadcastAllReduce(const void *sendbuff, void *recvbuff, size_t count); +template bool CollectiveOpsImpl::ReduceBroadcastAllReduce(const void *sendbuff, void *recvbuff, size_t count); +template bool CollectiveOpsImpl::ReduceBroadcastAllReduce(const void *sendbuff, void *recvbuff, size_t count); + +template bool CollectiveOpsImpl::AllReduce(const void *sendbuff, void *recvbuff, size_t count); +template bool CollectiveOpsImpl::AllReduce(const void *sendbuff, void *recvbuff, size_t count); +template bool CollectiveOpsImpl::AllReduce(const void *sendbuff, void *recvbuff, size_t count); +} // namespace server +} // namespace ps +} // namespace mindspore diff --git a/mindspore/ccsrc/ps/server/collective_ops_impl.h b/mindspore/ccsrc/ps/server/collective_ops_impl.h new file mode 100644 index 00000000000..ce4f5b14bd0 --- /dev/null +++ b/mindspore/ccsrc/ps/server/collective_ops_impl.h @@ -0,0 +1,71 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PS_SERVER_COLLECTIVE_OPS_IMPL_H_ +#define MINDSPORE_CCSRC_PS_SERVER_COLLECTIVE_OPS_IMPL_H_ + +#include +#include +#include +#include +#include "proto/ps.pb.h" +#include "ps/ps_context.h" +#include "ps/core/server_node.h" +#include "ps/server/common.h" + +namespace mindspore { +namespace ps { +namespace server { +// CollectiveOpsImpl is the collective communication API of the server. +// For now, it implements two AllReduce algorithms: RingAllReduce and BroadcastAllReduce. Elastic AllReduce is also +// supported for the elastic scaling feature of the server. +class CollectiveOpsImpl { + public: + static CollectiveOpsImpl &GetInstance() { + static CollectiveOpsImpl instance; + return instance; + } + + void Initialize(const std::shared_ptr &server_node); + + template + bool AllReduce(const void *sendbuff, void *recvbuff, size_t count); + + private: + CollectiveOpsImpl() = default; + ~CollectiveOpsImpl() = default; + CollectiveOpsImpl(const CollectiveOpsImpl &) = delete; + CollectiveOpsImpl &operator=(const CollectiveOpsImpl &) = delete; + + // Implementation of RingAllReduce. + template + bool RingAllReduce(const void *sendbuff, void *recvbuff, size_t count); + + // Implementation of BroadcastAllReduce. + template + bool ReduceBroadcastAllReduce(const void *sendbuff, void *recvbuff, size_t count); + + std::shared_ptr server_node_; + uint32_t local_rank_; + uint32_t server_num_; + + // The mutex to ensure that collective communication is threadsafe. + std::mutex mtx_; +}; +} // namespace server +} // namespace ps +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PS_SERVER_COLLECTIVE_OPS_IMPL_H_ diff --git a/mindspore/ccsrc/ps/server/common.h b/mindspore/ccsrc/ps/server/common.h index 053f00024e9..cf931b52013 100644 --- a/mindspore/ccsrc/ps/server/common.h +++ b/mindspore/ccsrc/ps/server/common.h @@ -24,13 +24,17 @@ #include #include #include "proto/ps.pb.h" +#include "proto/fl.pb.h" #include "ir/anf.h" #include "utils/utils.h" #include "ir/dtype/type_id.h" #include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "schema/fl_job_generated.h" +#include "schema/cipher_generated.h" #include "ps/ps_context.h" #include "ps/core/communicator/http_message_handler.h" #include "ps/core/communicator/tcp_server.h" +#include "ps/core/communicator/message_handler.h" namespace mindspore { namespace ps { @@ -40,13 +44,15 @@ enum ServerMode { PARAMETER_SERVER = 0, FL_SERVER }; enum CommType { HTTP = 0, TCP }; enum AggregationType { FedAvg = 0, FedAdam, FedAdagarg, FedMeta, qffl, DenseGradAccum, SparseGradAccum }; -using kernel::Address; -using kernel::AddressPtr; -using kernel::CPUKernel; +using mindspore::kernel::Address; +using mindspore::kernel::AddressPtr; +using mindspore::kernel::CPUKernel; +using FBBuilder = flatbuffers::FlatBufferBuilder; using TimeOutCb = std::function; using StopTimerCb = std::function; using FinishIterCb = std::function; using FinalizeCb = std::function; +using MessageCallback = std::function &)>; // Information about whether server kernel will reuse kernel node memory from the front end. // Key refers to the server kernel's parameter name, like "weights", "grad", "learning_rate". diff --git a/mindspore/ccsrc/ps/server/distributed_count_service.cc b/mindspore/ccsrc/ps/server/distributed_count_service.cc new file mode 100644 index 00000000000..749b0a5eef9 --- /dev/null +++ b/mindspore/ccsrc/ps/server/distributed_count_service.cc @@ -0,0 +1,298 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ps/server/distributed_count_service.h" +#include +#include +#include + +namespace mindspore { +namespace ps { +namespace server { +void DistributedCountService::Initialize(const std::shared_ptr &server_node, + uint32_t counting_server_rank) { + server_node_ = server_node; + MS_EXCEPTION_IF_NULL(server_node_); + + communicator_ = + std::dynamic_pointer_cast(server_node_->GetOrCreateTcpComm("", 0, 0, 0, nullptr)); + MS_EXCEPTION_IF_NULL(communicator_); + + local_rank_ = server_node_->rank_id(); + server_num_ = PSContext::instance()->initial_server_num(); + counting_server_rank_ = counting_server_rank; + RegisterCallback(); + return; +} + +void DistributedCountService::RegisterCounter(const std::string &name, size_t global_threshold_count, + const CounterHandlers &counter_handlers) { + if (!counter_handlers.first_count_handler || !counter_handlers.last_count_handler) { + MS_LOG(EXCEPTION) << "First count handler or last count handler is not set."; + return; + } + if (global_threshold_count_.count(name) != 0) { + MS_LOG(ERROR) << "Counter for " << name << " is already set."; + return; + } + + MS_LOG(INFO) << "Rank " << local_rank_ << " register counter for " << name << " count:" << global_threshold_count; + // If the server is the leader server, it needs to set the counter handlers and do the real counting. + if (local_rank_ == counting_server_rank_) { + global_current_count_[name] = {}; + global_threshold_count_[name] = global_threshold_count; + mutex_[name]; + } + counter_handlers_[name] = counter_handlers; + return; +} + +bool DistributedCountService::Count(const std::string &name, const std::string &id) { + MS_LOG(INFO) << "Rank " << local_rank_ << " reports count for " << name << " of " << id; + if (local_rank_ == counting_server_rank_) { + if (global_threshold_count_.count(name) == 0) { + MS_LOG(ERROR) << "Counter for " << name << " is not registered."; + return false; + } + + std::unique_lock lock(mutex_[name]); + if (global_current_count_[name].size() >= global_threshold_count_[name]) { + MS_LOG(ERROR) << "Count for " << name << " is already enough. Threshold count is " + << global_threshold_count_[name]; + return false; + } + + MS_LOG(INFO) << "Leader server increase count for " << name << " of " << id; + global_current_count_[name].insert(id); + TriggerCounterEvent(name); + } else { + // If this server is a follower server, it needs to send CountRequest to the leader server. + CountRequest report_count_req; + report_count_req.set_name(name); + report_count_req.set_id(id); + + std::shared_ptr> report_cnt_rsp_msg = nullptr; + if (!communicator_->SendPbRequest(report_count_req, counting_server_rank_, core::TcpUserCommand::kCount, + &report_cnt_rsp_msg)) { + MS_LOG(ERROR) << "Sending reporting count message to leader server failed for " << name; + return false; + } + + CountResponse count_rsp; + count_rsp.ParseFromArray(report_cnt_rsp_msg->data(), report_cnt_rsp_msg->size()); + if (!count_rsp.result()) { + MS_LOG(ERROR) << "Reporting count failed:" << count_rsp.reason(); + return false; + } + } + return true; +} + +bool DistributedCountService::CountReachThreshold(const std::string &name) { + MS_LOG(INFO) << "Rank " << local_rank_ << " query whether count reaches threshold for " << name; + if (local_rank_ == counting_server_rank_) { + if (global_threshold_count_.count(name) == 0) { + MS_LOG(ERROR) << "Counter for " << name << " is not set."; + return false; + } + + std::unique_lock lock(mutex_[name]); + return global_current_count_[name].size() == global_threshold_count_[name]; + } else { + CountReachThresholdRequest count_reach_threashold_req; + count_reach_threashold_req.set_name(name); + + std::shared_ptr> query_cnt_enough_rsp_msg = nullptr; + if (!communicator_->SendPbRequest(count_reach_threashold_req, counting_server_rank_, + core::TcpUserCommand::kReachThreshold, &query_cnt_enough_rsp_msg)) { + MS_LOG(ERROR) << "Sending querying whether count reaches threshold message to leader server failed for " << name; + return false; + } + + CountReachThresholdResponse count_reach_threashold_rsp; + count_reach_threashold_rsp.ParseFromArray(query_cnt_enough_rsp_msg->data(), query_cnt_enough_rsp_msg->size()); + return count_reach_threashold_rsp.is_enough(); + } +} + +void DistributedCountService::ResetCounter(const std::string &name) { + if (local_rank_ == counting_server_rank_) { + MS_LOG(INFO) << "Leader server reset count for " << name; + global_current_count_[name].clear(); + } + return; +} + +void DistributedCountService::RegisterCallback() { + if (local_rank_ == counting_server_rank_) { + communicator_->RegisterMsgCallBack( + "count", std::bind(&DistributedCountService::HandleCountRequest, this, std::placeholders::_1)); + communicator_->RegisterMsgCallBack( + "countReachThreshold", + std::bind(&DistributedCountService::HandleCountReachThresholdRequest, this, std::placeholders::_1)); + } + + // The callback of first/last event must be set in both leader server and follower servers. + communicator_->RegisterMsgCallBack( + "counterEvent", std::bind(&DistributedCountService::HandleCounterEvent, this, std::placeholders::_1)); +} + +void DistributedCountService::HandleCountRequest(const std::shared_ptr &message) { + if (message == nullptr) { + MS_LOG(ERROR) << "Message is nullptr."; + return; + } + + CountRequest report_count_req; + report_count_req.ParseFromArray(message->data(), message->len()); + const std::string &name = report_count_req.name(); + const std::string &id = report_count_req.id(); + + CountResponse count_rsp; + std::unique_lock lock(mutex_[name]); + // If leader server has no counter for the name registered, return an error. + if (global_threshold_count_.count(name) == 0) { + std::string reason = "Counter for " + name + " is not registered."; + count_rsp.set_result(false); + count_rsp.set_reason(reason); + MS_LOG(ERROR) << reason; + communicator_->SendResponse(count_rsp.SerializeAsString().data(), count_rsp.SerializeAsString().size(), message); + return; + } + + // If leader server already has enough count for the name, return an error. + if (global_current_count_[name].size() >= global_threshold_count_[name]) { + std::string reason = + "Count for " + name + " is already enough. Threshold count is " + std::to_string(global_threshold_count_[name]); + count_rsp.set_result(false); + count_rsp.set_reason(reason); + MS_LOG(ERROR) << reason; + communicator_->SendResponse(count_rsp.SerializeAsString().data(), count_rsp.SerializeAsString().size(), message); + return; + } + + // Insert the id for the counter, which means the count for the name is increased. + MS_LOG(INFO) << "Leader server increase count for " << name << " of " << id; + global_current_count_[name].insert(id); + TriggerCounterEvent(name); + count_rsp.set_result(true); + count_rsp.set_reason("success"); + communicator_->SendResponse(count_rsp.SerializeAsString().data(), count_rsp.SerializeAsString().size(), message); + return; +} + +void DistributedCountService::HandleCountReachThresholdRequest(const std::shared_ptr &message) { + if (message == nullptr) { + MS_LOG(ERROR) << "Message is nullptr."; + return; + } + + CountReachThresholdRequest count_reach_threashold_req; + count_reach_threashold_req.ParseFromArray(message->data(), message->len()); + const std::string &name = count_reach_threashold_req.name(); + + std::unique_lock lock(mutex_[name]); + if (global_threshold_count_.count(name) == 0) { + MS_LOG(ERROR) << "Counter for " << name << " is not registered."; + return; + } + + CountReachThresholdResponse count_reach_threashold_rsp; + count_reach_threashold_rsp.set_is_enough(global_current_count_[name].size() == global_threshold_count_[name]); + communicator_->SendResponse(count_reach_threashold_rsp.SerializeAsString().data(), + count_reach_threashold_rsp.SerializeAsString().size(), message); + return; +} + +void DistributedCountService::HandleCounterEvent(const std::shared_ptr &message) { + if (message == nullptr) { + MS_LOG(ERROR) << "Message is nullptr."; + return; + } + + // Respond as soon as possible so the leader server won't wait for each follower servers to finish calling the + // callbacks. + std::string couter_event_rsp_msg = "success"; + communicator_->SendResponse(couter_event_rsp_msg.data(), couter_event_rsp_msg.size(), message); + + CounterEvent counter_event; + counter_event.ParseFromArray(message->data(), message->len()); + const auto &type = counter_event.type(); + const auto &name = counter_event.name(); + + MS_LOG(INFO) << "Rank " << local_rank_ << " do counter event " << type << " for " << name; + if (type == CounterEventType::FIRST_CNT) { + counter_handlers_[name].first_count_handler(message); + } else if (type == CounterEventType::LAST_CNT) { + counter_handlers_[name].last_count_handler(message); + } else { + MS_LOG(ERROR) << "DistributedCountService event type " << type << " is invalid."; + return; + } + return; +} + +void DistributedCountService::TriggerCounterEvent(const std::string &name) { + MS_LOG(INFO) << "Current count for " << name << " is " << global_current_count_[name].size() + << ", threshold count is " << global_threshold_count_[name]; + // The threshold count may be 1 so the first and last count event should be both activated. + if (global_current_count_[name].size() == 1) { + TriggerFirstCountEvent(name); + } + if (global_current_count_[name].size() == global_threshold_count_[name]) { + TriggerLastCountEvent(name); + } + return; +} + +void DistributedCountService::TriggerFirstCountEvent(const std::string &name) { + MS_LOG(INFO) << "Activating first count event for " << name; + CounterEvent first_count_event; + first_count_event.set_type(CounterEventType::FIRST_CNT); + first_count_event.set_name(name); + + // Broadcast to all follower servers. + for (uint32_t i = 1; i < server_num_; i++) { + if (!communicator_->SendPbRequest(first_count_event, i, core::TcpUserCommand::kCounterEvent)) { + MS_LOG(ERROR) << "Activating first count event to server " << i << " failed."; + return; + } + } + // Leader server directly calls the callback. + counter_handlers_[name].first_count_handler(nullptr); + return; +} + +void DistributedCountService::TriggerLastCountEvent(const std::string &name) { + MS_LOG(INFO) << "Activating last count event for " << name; + CounterEvent last_count_event; + last_count_event.set_type(CounterEventType::LAST_CNT); + last_count_event.set_name(name); + + // Broadcast to all follower servers. + for (uint32_t i = 1; i < server_num_; i++) { + if (!communicator_->SendPbRequest(last_count_event, i, core::TcpUserCommand::kCounterEvent)) { + MS_LOG(ERROR) << "Activating last count event to server " << i << " failed."; + return; + } + } + // Leader server directly calls the callback. + counter_handlers_[name].last_count_handler(nullptr); + return; +} +} // namespace server +} // namespace ps +} // namespace mindspore diff --git a/mindspore/ccsrc/ps/server/distributed_count_service.h b/mindspore/ccsrc/ps/server/distributed_count_service.h new file mode 100644 index 00000000000..4407f61836d --- /dev/null +++ b/mindspore/ccsrc/ps/server/distributed_count_service.h @@ -0,0 +1,126 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PS_SERVER_DISTRIBUTED_COUNT_SERVICE_H_ +#define MINDSPORE_CCSRC_PS_SERVER_DISTRIBUTED_COUNT_SERVICE_H_ + +#include +#include +#include +#include +#include "proto/ps.pb.h" +#include "ps/server/common.h" +#include "ps/core/server_node.h" +#include "ps/core/communicator/tcp_communicator.h" + +namespace mindspore { +namespace ps { +namespace server { +// The callbacks for the first count and last count event. +typedef struct { + MessageCallback first_count_handler; + MessageCallback last_count_handler; +} CounterHandlers; + +// DistributedCountService is used for counting in the server cluster dimension. It's used for counting of rounds, +// aggregation counting, etc. + +// The counting could be called by any server, but only one server has the information +// of the cluster count and we mark this server as the counting server. Other servers must communicate with this +// counting server to increase/query count number. + +// On the first count or last count event, DistributedCountService on the counting server triggers the event on other +// servers by sending counter event commands. This is for the purpose of keeping server cluster's consistency. +class DistributedCountService { + public: + static DistributedCountService &GetInstance() { + static DistributedCountService instance; + return instance; + } + + // Initialize counter service with the server node because communication is needed. + void Initialize(const std::shared_ptr &server_node, uint32_t counting_server_rank); + + // Register counter to the counting server for the name with its threshold count in server cluster dimension and + // first/last count event callbacks. + void RegisterCounter(const std::string &name, size_t global_threshold_count, const CounterHandlers &counter_handlers); + + // Report a count to the counting server. Parameter 'id' is in case of repeated counting. + bool Count(const std::string &name, const std::string &id); + + // Query whether the count reaches the threshold count for the name. If the count is the same as the threshold count, + // this method returns true. + bool CountReachThreshold(const std::string &name); + + // Reset the count of the name to 0. + void ResetCounter(const std::string &name); + + // Returns the server rank because in some cases the callers use this rank as the 'id' for method + // Count. + uint32_t local_rank() { return local_rank_; } + + private: + DistributedCountService() = default; + ~DistributedCountService() = default; + DistributedCountService(const DistributedCountService &) = delete; + DistributedCountService &operator=(const DistributedCountService &) = delete; + + // Register callbacks of the counting server to handle messages sent by the other servers. + void RegisterCallback(); + + // Callback for the reporting count message from other servers. Only counting server will call this method. + void HandleCountRequest(const std::shared_ptr &message); + + // Callback for the querying whether threshold count is reached message from other servers. Only counting + // server will call this method. + void HandleCountReachThresholdRequest(const std::shared_ptr &message); + + // Callback for the first/last event message from the counting server. Only other servers will call this + // method. + void HandleCounterEvent(const std::shared_ptr &message); + + // Call the callbacks when the first/last count event is triggered. + void TriggerCounterEvent(const std::string &name); + void TriggerFirstCountEvent(const std::string &name); + void TriggerLastCountEvent(const std::string &name); + + // Members for the communication between counting server and other servers. + std::shared_ptr server_node_; + std::shared_ptr communicator_; + uint32_t local_rank_; + uint32_t server_num_; + + // Only one server will be set to do the real counting. + uint32_t counting_server_rank_; + + // Key: name, e.g, startFLJob, updateModel, push. + // Value: a set of id without repeatation because each work may report multiple times. + std::unordered_map> global_current_count_; + + // Key: name, e.g, StartFLJobCount. + // Value: global threshold count in the server cluster dimension for this name. + std::unordered_map global_threshold_count_; + + // First/last count event callbacks of the name. + std::unordered_map counter_handlers_; + + // Because the count is increased/queried conccurently, we must ensure the operations are threadsafe. + std::unordered_map mutex_; +}; +} // namespace server +} // namespace ps +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PS_SERVER_DISTRIBUTED_COUNT_SERVICE_H_ diff --git a/mindspore/ccsrc/ps/server/distributed_metadata_store.cc b/mindspore/ccsrc/ps/server/distributed_metadata_store.cc new file mode 100644 index 00000000000..b4e49c64ea3 --- /dev/null +++ b/mindspore/ccsrc/ps/server/distributed_metadata_store.cc @@ -0,0 +1,201 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ps/server/distributed_metadata_store.h" +#include +#include +#include + +namespace mindspore { +namespace ps { +namespace server { +void DistributedMetadataStore::Initialize(const std::shared_ptr &server_node) { + server_node_ = server_node; + MS_EXCEPTION_IF_NULL(server_node); + + communicator_ = + std::dynamic_pointer_cast(server_node_->GetOrCreateTcpComm("", 0, 0, 0, nullptr)); + MS_EXCEPTION_IF_NULL(communicator_); + + local_rank_ = server_node_->rank_id(); + server_num_ = PSContext::instance()->initial_server_num(); + + InitHashRing(); + RegisterCallback(); + return; +} + +void DistributedMetadataStore::RegisterMetadata(const std::string &name, const PBMetadata &meta) { + if (router_ == nullptr) { + MS_LOG(ERROR) << "The consistent hash ring is not initialized yet."; + return; + } + + uint32_t stored_rank = router_->Find(name); + if (local_rank_ == stored_rank) { + if (metadata_.count(name) != 0) { + MS_LOG(ERROR) << "The metadata for " << name << " is already registered."; + return; + } + + MS_LOG(INFO) << "Rank " << local_rank_ << " register storage for metadata " << name; + metadata_[name] = meta; + mutex_[name]; + } + return; +} + +void DistributedMetadataStore::ResetMetadata(const std::string &name) { + if (router_ == nullptr) { + MS_LOG(ERROR) << "The consistent hash ring is not initialized yet."; + return; + } + + uint32_t stored_rank = router_->Find(name); + if (local_rank_ == stored_rank) { + if (metadata_.count(name) == 0) { + MS_LOG(ERROR) << "The metadata for " << name << " is not registered."; + return; + } + + MS_LOG(INFO) << "Rank " << local_rank_ << " reset metadata for " << name; + std::unique_lock lock(mutex_[name]); + PBMetadata empty_meta; + metadata_[name] = empty_meta; + } + return; +} + +void DistributedMetadataStore::UpdateMetadata(const std::string &name, const PBMetadata &meta) { + if (router_ == nullptr) { + MS_LOG(ERROR) << "The consistent hash ring is not initialized yet."; + return; + } + + uint32_t stored_rank = router_->Find(name); + MS_LOG(INFO) << "Rank " << local_rank_ << " update value for " << name << " which is stored in rank " << stored_rank; + if (local_rank_ == stored_rank) { + if (!DoUpdateMetadata(name, meta)) { + MS_LOG(ERROR) << "Updating meta data failed."; + return; + } + } else { + PBMetadataWithName metadata_with_name; + metadata_with_name.set_name(name); + *metadata_with_name.mutable_metadata() = meta; + if (!communicator_->SendPbRequest(metadata_with_name, stored_rank, core::TcpUserCommand::kUpdateMetadata)) { + MS_LOG(ERROR) << "Sending updating metadata message to server " << stored_rank << " failed."; + return; + } + } + return; +} + +PBMetadata DistributedMetadataStore::GetMetadata(const std::string &name) { + if (router_ == nullptr) { + MS_LOG(ERROR) << "The consistent hash ring is not initialized yet."; + return {}; + } + + uint32_t stored_rank = router_->Find(name); + MS_LOG(INFO) << "Rank " << local_rank_ << " get metadata for " << name << " which is stored in rank " << stored_rank; + if (local_rank_ == stored_rank) { + std::unique_lock lock(mutex_[name]); + return metadata_[name]; + } else { + GetMetadataRequest get_metadata_req; + get_metadata_req.set_name(name); + PBMetadata get_metadata_rsp; + + std::shared_ptr> get_meta_rsp_msg = nullptr; + if (!communicator_->SendPbRequest(get_metadata_req, stored_rank, core::TcpUserCommand::kGetMetadata, + &get_meta_rsp_msg)) { + MS_LOG(ERROR) << "Sending getting metadata message to server " << stored_rank << " failed."; + return get_metadata_rsp; + } + get_metadata_rsp.ParseFromArray(get_meta_rsp_msg->data(), get_meta_rsp_msg->size()); + return get_metadata_rsp; + } +} + +void DistributedMetadataStore::InitHashRing() { + router_ = std::make_shared(32); + MS_EXCEPTION_IF_NULL(router_); + for (uint32_t i = 0; i < server_num_; i++) { + bool ret = router_->Insert(i); + if (!ret) { + MS_LOG(EXCEPTION) << "Add node " << i << " to router of meta storage failed."; + return; + } + } + return; +} + +void DistributedMetadataStore::RegisterCallback() { + communicator_->RegisterMsgCallBack( + "updateMetadata", std::bind(&DistributedMetadataStore::HandleUpdateMetadataRequest, this, std::placeholders::_1)); + communicator_->RegisterMsgCallBack( + "getMetadata", std::bind(&DistributedMetadataStore::HandleGetMetadataRequest, this, std::placeholders::_1)); + return; +} + +void DistributedMetadataStore::HandleUpdateMetadataRequest(const std::shared_ptr &message) { + if (message == nullptr) { + MS_LOG(ERROR) << "Message is nullptr."; + return; + } + + PBMetadataWithName meta_with_name; + meta_with_name.ParseFromArray(message->data(), message->len()); + const std::string &name = meta_with_name.name(); + MS_LOG(INFO) << "Update metadata for " << name; + + std::string update_meta_rsp_msg; + if (!DoUpdateMetadata(name, meta_with_name.metadata())) { + update_meta_rsp_msg = "Updating meta data failed."; + } else { + update_meta_rsp_msg = "Success"; + } + communicator_->SendResponse(update_meta_rsp_msg.data(), update_meta_rsp_msg.size(), message); + return; +} + +void DistributedMetadataStore::HandleGetMetadataRequest(const std::shared_ptr &message) { + if (message == nullptr) { + MS_LOG(ERROR) << "Message is nullptr."; + return; + } + + GetMetadataRequest get_metadata_req; + get_metadata_req.ParseFromArray(message->data(), message->len()); + const std::string &name = get_metadata_req.name(); + MS_LOG(INFO) << "Getting metadata for " << name; + + std::unique_lock lock(mutex_[name]); + PBMetadata stored_meta = metadata_[name]; + std::string getting_meta_rsp_msg = stored_meta.SerializeAsString(); + communicator_->SendResponse(getting_meta_rsp_msg.data(), getting_meta_rsp_msg.size(), message); + return; +} + +bool DistributedMetadataStore::DoUpdateMetadata(const std::string &name, const PBMetadata &meta) { + std::unique_lock lock(mutex_[name]); + metadata_[name] = meta; + return true; +} +} // namespace server +} // namespace ps +} // namespace mindspore diff --git a/mindspore/ccsrc/ps/server/distributed_metadata_store.h b/mindspore/ccsrc/ps/server/distributed_metadata_store.h new file mode 100644 index 00000000000..b8ffe25b27f --- /dev/null +++ b/mindspore/ccsrc/ps/server/distributed_metadata_store.h @@ -0,0 +1,101 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PS_SERVER_DISTRIBUTED_META_STORE_H_ +#define MINDSPORE_CCSRC_PS_SERVER_DISTRIBUTED_META_STORE_H_ + +#include +#include +#include +#include "proto/ps.pb.h" +#include "ps/server/common.h" +#include "ps/core/server_node.h" +#include "ps/core/communicator/tcp_communicator.h" +#include "ps/server/consistent_hash_ring.h" + +namespace mindspore { +namespace ps { +namespace server { +// This class is used for distributed metadata storage using consistent hash. All metadata is distributedly +// stored in all servers. Caller doesn't need to know which server stores the metadata. It only needs to know what kind +// of operations should be done to the metadata. + +// The metadata stored in the server is in protobuffer format because it's easy for serializing and communicating. The +// type of the protobuffer struct is decided by the caller using protobuffer's API. +class DistributedMetadataStore { + public: + static DistributedMetadataStore &GetInstance() { + static DistributedMetadataStore instance; + return instance; + } + + // Initialize metadata storage with the server node because communication is needed. + void Initialize(const std::shared_ptr &server_node); + + // Register metadata for the name with the initial value. This method should be only called once for each name. + void RegisterMetadata(const std::string &name, const PBMetadata &meta); + + // Reset the metadata value for the name. + void ResetMetadata(const std::string &name); + + // Update the metadata for the name. + void UpdateMetadata(const std::string &name, const PBMetadata &meta); + + // Get the metadata for the name. + PBMetadata GetMetadata(const std::string &name); + + private: + DistributedMetadataStore() = default; + ~DistributedMetadataStore() = default; + DistributedMetadataStore(const DistributedMetadataStore &) = delete; + DistributedMetadataStore &operator=(const DistributedMetadataStore &) = delete; + + // Initialize the consistent hash ring for distributed storage. + void InitHashRing(); + + // Register callbacks for the server to handle update/get metadata messages from other servers. + void RegisterCallback(); + + // Callback for updating metadata request sent to the server. + void HandleUpdateMetadataRequest(const std::shared_ptr &message); + + // Callback for getting metadata request sent to the server. + void HandleGetMetadataRequest(const std::shared_ptr &message); + + // Do updating metadata in the server where the metadata for the name is stored. + bool DoUpdateMetadata(const std::string &name, const PBMetadata &meta); + + // Members for the communication between servers. + std::shared_ptr server_node_; + std::shared_ptr communicator_; + uint32_t local_rank_; + uint32_t server_num_; + + // Consistent hash ring. This is used for DistributedMetadataStore to find which server node the meta data is stored. + std::shared_ptr router_; + + // We store metadata which is serialized by ProtoBuffer so that data storage and data transmission API is easy to use. + // Key: data name. + // Value: ProtoBuffer Struct. + std::unordered_map metadata_; + + // Because the metadata is read/written conccurently, we must ensure the operations are threadsafe. + std::unordered_map mutex_; +}; +} // namespace server +} // namespace ps +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PS_SERVER_DISTRIBUTED_META_STORE_H_ diff --git a/mindspore/ccsrc/ps/server/executor.cc b/mindspore/ccsrc/ps/server/executor.cc index 411ca988531..949f4f0f131 100644 --- a/mindspore/ccsrc/ps/server/executor.cc +++ b/mindspore/ccsrc/ps/server/executor.cc @@ -169,7 +169,7 @@ bool Executor::HandleOverwriteWeightsByKey(const std::map } AddressPtr Executor::HandlePull(const std::string ¶m_name) { - MS_LOG(INFO) << "Handle blocking pull msg for parameter " << param_name; + MS_LOG(INFO) << "Handle blocking pull message for parameter " << param_name; if (param_aggrs_.count(param_name) == 0) { MS_LOG(WARNING) << "Parameter " << param_name << " is not registered in server."; return nullptr; @@ -193,11 +193,6 @@ AddressPtr Executor::HandlePull(const std::string ¶m_name) { return addr; } -std::map Executor::HandleAsyncGetModel() { - std::unique_lock lock(model_mutex_); - return GetModel(); -} - std::map Executor::HandleGetWeightsByKey(const std::vector ¶m_names) { std::map weights; for (const auto ¶m_name : param_names) { diff --git a/mindspore/ccsrc/ps/server/executor.h b/mindspore/ccsrc/ps/server/executor.h index 0befb3dd304..ef4eafb6af6 100644 --- a/mindspore/ccsrc/ps/server/executor.h +++ b/mindspore/ccsrc/ps/server/executor.h @@ -63,10 +63,6 @@ class Executor { // asynchronously. bool HandleModelUpdateAsync(const std::map &feature_map); - // Called in asynchronous federated learning training mode. Returns whole model in key-value where key refers to the - // parameter name. - std::map HandleAsyncGetModel(); - // Forcibly overwrite specific weights in overwriteWeights message. bool HandleOverwriteWeightsByKey(const std::map &feature_map); diff --git a/mindspore/ccsrc/ps/server/iteration.cc b/mindspore/ccsrc/ps/server/iteration.cc new file mode 100644 index 00000000000..85b3c6b8685 --- /dev/null +++ b/mindspore/ccsrc/ps/server/iteration.cc @@ -0,0 +1,76 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ps/server/iteration.h" +#include +#include +#include +#include "ps/server/model_store.h" + +namespace mindspore { +namespace ps { +namespace server { +Iteration::Iteration() : iteration_num_(1) { LocalMetaStore::GetInstance().set_curr_iter_num(iteration_num_); } + +void Iteration::AddRound(const std::shared_ptr &round) { + MS_EXCEPTION_IF_NULL(round); + rounds_.push_back(round); +} + +void Iteration::InitRounds(const std::vector> &communicators, + const TimeOutCb &timeout_cb, const FinishIterCb &finish_iteration_cb) { + if (communicators.empty()) { + MS_LOG(EXCEPTION) << "Communicators for rounds is empty."; + return; + } + + std::for_each(communicators.begin(), communicators.end(), + [&](const std::shared_ptr &communicator) { + for (auto &round : rounds_) { + if (round == nullptr) { + continue; + } + round->Initialize(communicator, timeout_cb, finish_iteration_cb); + } + }); + + // The time window for one iteration, which will be used in some round kernels. + size_t iteration_time_window = + std::accumulate(rounds_.begin(), rounds_.end(), 0, + [](size_t total, const std::shared_ptr &round) { return total + round->time_window(); }); + LocalMetaStore::GetInstance().put_value(kCtxTotalTimeoutDuration, iteration_time_window); + return; +} + +void Iteration::ProceedToNextIter() { + iteration_num_ = LocalMetaStore::GetInstance().curr_iter_num(); + // Store the model for each iteration. + const auto &model = Executor::GetInstance().GetModel(); + ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model); + + for (auto &round : rounds_) { + round->Reset(); + } + + iteration_num_++; + LocalMetaStore::GetInstance().set_curr_iter_num(iteration_num_); + MS_LOG(INFO) << "Proceed to next iteration:" << iteration_num_ << "\n"; +} + +const std::vector> &Iteration::rounds() { return rounds_; } +} // namespace server +} // namespace ps +} // namespace mindspore diff --git a/mindspore/ccsrc/ps/server/iteration.h b/mindspore/ccsrc/ps/server/iteration.h new file mode 100644 index 00000000000..0da4d8f9b7c --- /dev/null +++ b/mindspore/ccsrc/ps/server/iteration.h @@ -0,0 +1,58 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PS_SERVER_ITERATION_H_ +#define MINDSPORE_CCSRC_PS_SERVER_ITERATION_H_ + +#include +#include +#include "ps/core/communicator/communicator_base.h" +#include "ps/server/common.h" +#include "ps/server/round.h" +#include "ps/server/local_meta_store.h" + +namespace mindspore { +namespace ps { +namespace server { +// In server's logic, Iteration is the minimum execution unit. For each execution, it consists of multiple kinds of +// Rounds, only after all the rounds are finished, this iteration is considered as completed. +class Iteration { + public: + Iteration(); + ~Iteration() = default; + + // Add a round for the iteration. This method will be called multiple times for each round. + void AddRound(const std::shared_ptr &round); + + // Initialize all the rounds in the iteration. + void InitRounds(const std::vector> &communicators, + const TimeOutCb &timeout_cb, const FinishIterCb &finish_iteration_cb); + + // The server proceeds to the next iteration only after the last iteration finishes. + void ProceedToNextIter(); + + const std::vector> &rounds(); + + private: + std::vector> rounds_; + + // Server's current iteration number. + size_t iteration_num_; +}; +} // namespace server +} // namespace ps +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PS_SERVER_ITERATION_H_ diff --git a/mindspore/ccsrc/ps/server/kernel/round/round_kernel.cc b/mindspore/ccsrc/ps/server/kernel/round/round_kernel.cc new file mode 100644 index 00000000000..08dbe29a413 --- /dev/null +++ b/mindspore/ccsrc/ps/server/kernel/round/round_kernel.cc @@ -0,0 +1,127 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ps/server/kernel/round/round_kernel.h" +#include +#include +#include +#include +#include +#include +#include + +namespace mindspore { +namespace ps { +namespace server { +namespace kernel { +RoundKernel::RoundKernel() : name_(""), current_count_(0), required_count_(0), error_reason_("") { + release_thread_ = std::thread([&]() { + while (true) { + std::unique_lock release_lock(release_mtx_); + // Detect whether there's any data needs to be released every 100 milliseconds. + if (heap_data_to_release_.empty()) { + release_lock.unlock(); + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + continue; + } + + AddressPtr addr_ptr = heap_data_to_release_.front(); + heap_data_to_release_.pop(); + release_lock.unlock(); + + std::unique_lock heap_data_lock(heap_data_mtx_); + if (heap_data_.count(addr_ptr) == 0) { + MS_LOG(ERROR) << "The data is not stored."; + continue; + } + // Manually release unique_ptr data. + heap_data_[addr_ptr].reset(nullptr); + heap_data_.erase(heap_data_.find(addr_ptr)); + } + }); + release_thread_.detach(); +} + +void RoundKernel::OnFirstCountEvent(const std::shared_ptr &message) { return; } + +void RoundKernel::OnLastCountEvent(const std::shared_ptr &message) { return; } + +void RoundKernel::StopTimer() { + if (stop_timer_cb_) { + stop_timer_cb_(); + } + return; +} + +void RoundKernel::FinishIteration() { + if (finish_iteration_cb_) { + finish_iteration_cb_(); + } + return; +} + +void RoundKernel::Release(AddressPtr addr_ptr) { + if (addr_ptr == nullptr) { + MS_LOG(ERROR) << "Data to be released is empty."; + return; + } + std::unique_lock lock(release_mtx_); + heap_data_to_release_.push(addr_ptr); + return; +} + +void RoundKernel::set_name(const std::string &name) { name_ = name; } + +void RoundKernel::set_stop_timer_cb(StopTimerCb timer_stopper) { stop_timer_cb_ = timer_stopper; } + +void RoundKernel::set_finish_iteration_cb(FinishIterCb finish_iteration_cb) { + finish_iteration_cb_ = finish_iteration_cb; +} + +void RoundKernel::GenerateOutput(const std::vector &outputs, void *data, size_t len) { + if (data == nullptr) { + MS_LOG(ERROR) << "The data is nullptr."; + return; + } + + if (outputs.empty()) { + MS_LOG(ERROR) << "Generating output failed. Outputs size is empty."; + return; + } + + std::unique_ptr output_data = std::make_unique(len); + if (output_data == nullptr) { + MS_LOG(ERROR) << "Output data is nullptr."; + return; + } + + size_t dst_size = len; + int ret = memcpy_s(output_data.get(), dst_size, data, len); + if (ret != 0) { + MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; + return; + } + outputs[0]->addr = output_data.get(); + outputs[0]->size = len; + + std::unique_lock lock(heap_data_mtx_); + heap_data_.insert(std::make_pair(outputs[0], std::move(output_data))); + return; +} +} // namespace kernel +} // namespace server +} // namespace ps +} // namespace mindspore diff --git a/mindspore/ccsrc/ps/server/kernel/round/round_kernel.h b/mindspore/ccsrc/ps/server/kernel/round/round_kernel.h new file mode 100644 index 00000000000..cd4861914cf --- /dev/null +++ b/mindspore/ccsrc/ps/server/kernel/round/round_kernel.h @@ -0,0 +1,130 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_ROUND_ROUND_KERNEL_H_ +#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_ROUND_ROUND_KERNEL_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "backend/kernel_compiler/common_utils.h" +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "ps/server/common.h" +#include "ps/server/local_meta_store.h" +#include "ps/server/distributed_count_service.h" +#include "ps/server/distributed_metadata_store.h" + +namespace mindspore { +namespace ps { +namespace server { +namespace kernel { +// RoundKernel contains the main logic of server handling messages from workers. One iteration has multiple round +// kernels to represent the process. They receive and parse messages from the server communication module. After +// handling these messages, round kernels allocate response data and send it back. + +// For example, the main process of federated learning is: +// startFLJob round->updateModel round->getModel round. +class RoundKernel : virtual public CPUKernel { + public: + RoundKernel(); + virtual ~RoundKernel() = default; + + // RoundKernel doesn't use InitKernel method of base class CPUKernel to initialize. So implementation of this + // inherited method is empty. + void InitKernel(const CNodePtr &kernel_node) override {} + + // Initialize RoundKernel with threshold_count which means that for every iteration, this round needs threshold_count + // messages. + virtual void InitKernel(size_t threshold_count) = 0; + + // Launch the round kernel logic to handle the message passed by the communication module. + virtual bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) = 0; + + // The callbacks when first message and last message for this round kernel is received. + // These methods is called by class DistributedCountService and triggered by leader server(Rank 0). + // virtual void OnFirstCountEvent(std::shared_ptr message); + // virtual void OnLastCnt(std::shared_ptr message); + + // Some rounds could be stateful in a iteration. Reset method resets the status of this round. + virtual bool Reset() = 0; + + // The counter event handlers for DistributedCountService. + virtual void OnFirstCountEvent(const std::shared_ptr &message); + virtual void OnLastCountEvent(const std::shared_ptr &message); + + // Called when this round is finished. This round timer's Stop method will be called. + void StopTimer(); + + // Called after this iteration(including all rounds) is finished. All rounds' Reset method will + // be called. + void FinishIteration(); + + // Release the response data allocated inside the round kernel. + // Server framework must call this after the response data is sent back. + void Release(AddressPtr addr_ptr); + + // Set round kernel name, which could be used in round kernel's methods. + void set_name(const std::string &name); + + // Set callbacks to be called under certain triggered conditions. + void set_stop_timer_cb(StopTimerCb timer_stopper); + void set_finish_iteration_cb(FinishIterCb finish_iteration_cb); + + protected: + // Generating response data of this round. The data is allocated on the heap to ensure it's not released before sent + // back to worker. + void GenerateOutput(const std::vector &outputs, void *data, size_t len); + + // Round kernel's name. + std::string name_; + + // The current received message count for this round in this iteration. + size_t current_count_; + + // The required received message count for this round in one iteration. + size_t required_count_; + + // The reason causes the error in this round kernel. + std::string error_reason_; + + StopTimerCb stop_timer_cb_; + FinishIterCb finish_iteration_cb_; + + // Members below are used for allocating and releasing response data on the heap. + + // To ensure the performance, we use another thread to release data on the heap. So the operation on the data should + // be threadsafe. + std::thread release_thread_; + + // Data needs to be released and its mutex; + std::mutex release_mtx_; + std::queue heap_data_to_release_; + std::mutex heap_data_mtx_; + std::unordered_map> heap_data_; +}; +} // namespace kernel +} // namespace server +} // namespace ps +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_ROUND_ROUND_KERNEL_H_ diff --git a/mindspore/ccsrc/ps/server/kernel/round/round_kernel_factory.cc b/mindspore/ccsrc/ps/server/kernel/round/round_kernel_factory.cc new file mode 100644 index 00000000000..d669b74d879 --- /dev/null +++ b/mindspore/ccsrc/ps/server/kernel/round/round_kernel_factory.cc @@ -0,0 +1,44 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ps/server/kernel/round/round_kernel_factory.h" + +namespace mindspore { +namespace ps { +namespace server { +namespace kernel { +RoundKernelFactory &RoundKernelFactory::GetInstance() { + static RoundKernelFactory instance; + return instance; +} + +void RoundKernelFactory::Register(const std::string &name, RoundKernelCreator &&creator) { + name_to_creator_map_[name] = creator; +} + +std::shared_ptr RoundKernelFactory::Create(const std::string &name) { + if (name_to_creator_map_.count(name) == 0) { + MS_LOG(ERROR) << "Round kernel " << name << " is not registered."; + return nullptr; + } + auto kernel = name_to_creator_map_[name](); + kernel->set_name(name); + return kernel; +} +} // namespace kernel +} // namespace server +} // namespace ps +} // namespace mindspore diff --git a/mindspore/ccsrc/ps/server/kernel/round/round_kernel_factory.h b/mindspore/ccsrc/ps/server/kernel/round/round_kernel_factory.h new file mode 100644 index 00000000000..45c4226ad8e --- /dev/null +++ b/mindspore/ccsrc/ps/server/kernel/round/round_kernel_factory.h @@ -0,0 +1,62 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_ROUND_ROUND_KERNEL_FACTORY_H_ +#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_ROUND_ROUND_KERNEL_FACTORY_H_ + +#include +#include +#include +#include +#include "ps/server/common.h" +#include "ps/server/kernel/round/round_kernel.h" + +namespace mindspore { +namespace ps { +namespace server { +namespace kernel { +using RoundKernelCreator = std::function()>; +// Kernel factory of round kernels. +class RoundKernelFactory { + public: + static RoundKernelFactory &GetInstance(); + void Register(const std::string &name, RoundKernelCreator &&creator); + std::shared_ptr Create(const std::string &name); + + private: + RoundKernelFactory() = default; + ~RoundKernelFactory() = default; + RoundKernelFactory(const RoundKernelFactory &) = delete; + RoundKernelFactory &operator=(const RoundKernelFactory &) = delete; + + std::unordered_map name_to_creator_map_; +}; + +class RoundKernelRegister { + public: + RoundKernelRegister(const std::string &name, RoundKernelCreator &&creator) { + RoundKernelFactory::GetInstance().Register(name, std::move(creator)); + } +}; + +#define REG_ROUND_KERNEL(NAME, CLASS) \ + static_assert(std::is_base_of::value, " must be base of RoundKernel"); \ + static const RoundKernelRegister g_##NAME##_round_kernel_reg(#NAME, []() { return std::make_shared(); }); +} // namespace kernel +} // namespace server +} // namespace ps +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_ROUND_ROUND_KERNEL_FACTORY_H_ diff --git a/mindspore/ccsrc/ps/server/kernel/round/start_fl_job_kernel.cc b/mindspore/ccsrc/ps/server/kernel/round/start_fl_job_kernel.cc new file mode 100644 index 00000000000..3c277bdf45d --- /dev/null +++ b/mindspore/ccsrc/ps/server/kernel/round/start_fl_job_kernel.cc @@ -0,0 +1,192 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ps/server/kernel/round/start_fl_job_kernel.h" +#include +#include +#include +#include + +namespace mindspore { +namespace ps { +namespace server { +namespace kernel { +void StartFLJobKernel::InitKernel(size_t) { + if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) { + iteration_time_window_ = LocalMetaStore::GetInstance().value(kCtxTotalTimeoutDuration); + } + + executor_ = &Executor::GetInstance(); + MS_EXCEPTION_IF_NULL(executor_); + if (!executor_->initialized()) { + MS_LOG(EXCEPTION) << "Executor must be initialized in server pipeline."; + return; + } + + PBMetadata devices_metas; + DistributedMetadataStore::GetInstance().RegisterMetadata(kCtxDeviceMetas, devices_metas); + return; +} + +bool StartFLJobKernel::Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) { + MS_LOG(INFO) << "Launching StartFLJobKernel kernel."; + if (inputs.size() != 1 || outputs.size() != 1) { + MS_LOG(ERROR) << "inputs or outputs size is invalid."; + return false; + } + void *req_data = inputs[0]->addr; + const std::shared_ptr &fbb = std::make_shared(); + if (fbb == nullptr || req_data == nullptr) { + MS_LOG(ERROR) << "FBBuilder builder or req_data is nullptr."; + return false; + } + + if (ReachThresholdForStartFLJob(fbb)) { + GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); + return false; + } + + const schema::RequestFLJob *start_fl_job_req = flatbuffers::GetRoot(req_data); + DeviceMeta device_meta = CreateDeviceMetadata(start_fl_job_req); + if (!ReadyForStartFLJob(fbb, device_meta)) { + GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); + return false; + } + // If calling ReportCount before ReadyForStartFLJob, the result will be inconsistent if the device is not selected. + if (!CountForStartFLJob(fbb, start_fl_job_req)) { + GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); + return false; + } + + StartFLJob(fbb, device_meta); + GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); + return true; +} + +bool StartFLJobKernel::Reset() { + MS_LOG(INFO) << "Starting fl job kernel reset!"; + StopTimer(); + DistributedCountService::GetInstance().ResetCounter(name_); + DistributedMetadataStore::GetInstance().ResetMetadata(kCtxDeviceMetas); + return true; +} + +bool StartFLJobKernel::ReachThresholdForStartFLJob(const std::shared_ptr &fbb) { + if (DistributedCountService::GetInstance().CountReachThreshold(name_)) { + std::string reason = "Current amount for startFLJob has reached the threshold. Please startFLJob later."; + BuildStartFLJobRsp(fbb, schema::ResponseCode_OutOfTime, reason, false, + std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_)); + MS_LOG(ERROR) << reason; + return true; + } + return false; +} + +DeviceMeta StartFLJobKernel::CreateDeviceMetadata(const schema::RequestFLJob *start_fl_job_req) { + std::string fl_name = start_fl_job_req->fl_name()->str(); + std::string fl_id = start_fl_job_req->fl_id()->str(); + int data_size = start_fl_job_req->data_size(); + MS_LOG(INFO) << "DeviceMeta fl_name:" << fl_name << ", fl_id:" << fl_id << ", data_size:" << data_size; + + DeviceMeta device_meta; + device_meta.set_fl_name(fl_name); + device_meta.set_fl_id(fl_id); + device_meta.set_data_size(data_size); + return device_meta; +} + +bool StartFLJobKernel::ReadyForStartFLJob(const std::shared_ptr &fbb, const DeviceMeta &device_meta) { + bool ret = true; + std::string reason = ""; + if (device_meta.data_size() < 1) { + reason = "FL job data size is not enough."; + ret = false; + } + if (!ret) { + BuildStartFLJobRsp(fbb, schema::ResponseCode_NotSelected, reason, false, + std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_)); + MS_LOG(ERROR) << reason; + } + return ret; +} + +bool StartFLJobKernel::CountForStartFLJob(const std::shared_ptr &fbb, + const schema::RequestFLJob *start_fl_job_req) { + if (!DistributedCountService::GetInstance().Count(name_, start_fl_job_req->fl_id()->str())) { + std::string reason = "startFLJob counting failed."; + BuildStartFLJobRsp(fbb, schema::ResponseCode_OutOfTime, reason, false, + std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_)); + MS_LOG(ERROR) << reason; + return false; + } + return true; +} + +void StartFLJobKernel::StartFLJob(const std::shared_ptr &fbb, const DeviceMeta &device_meta) { + PBMetadata metadata; + *metadata.mutable_device_meta() = device_meta; + DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxDeviceMetas, metadata); + + std::map feature_maps = executor_->GetModel(); + BuildStartFLJobRsp(fbb, schema::ResponseCode_SUCCEED, "success", true, + std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_), feature_maps); + return; +} + +void StartFLJobKernel::BuildStartFLJobRsp(const std::shared_ptr &fbb, const schema::ResponseCode retcode, + const std::string &reason, const bool is_selected, + const std::string &next_req_time, + std::map feature_maps) { + auto fbs_reason = fbb->CreateString(reason); + auto fbs_next_req_time = fbb->CreateString(next_req_time); + auto fbs_fl_name = fbb->CreateString(PSContext::instance()->fl_name()); + + schema::FLPlanBuilder fl_plan_builder(*(fbb.get())); + fl_plan_builder.add_fl_name(fbs_fl_name); + fl_plan_builder.add_iterations(PSContext::instance()->fl_iteration_num()); + fl_plan_builder.add_epochs(PSContext::instance()->client_epoch_num()); + fl_plan_builder.add_mini_batch(PSContext::instance()->client_batch_size()); + auto fbs_fl_plan = fl_plan_builder.Finish(); + + std::vector> fbs_feature_maps; + for (auto feature_map : feature_maps) { + auto fbs_weight_fullname = fbb->CreateString(feature_map.first); + auto fbs_weight_data = + fbb->CreateVector(reinterpret_cast(feature_map.second->addr), feature_map.second->size / sizeof(float)); + auto fbs_feature_map = schema::CreateFeatureMap(*(fbb.get()), fbs_weight_fullname, fbs_weight_data); + fbs_feature_maps.push_back(fbs_feature_map); + } + auto fbs_feature_maps_vector = fbb->CreateVector(fbs_feature_maps); + + schema::ResponseFLJobBuilder rsp_fl_job_builder(*(fbb.get())); + rsp_fl_job_builder.add_retcode(retcode); + rsp_fl_job_builder.add_reason(fbs_reason); + rsp_fl_job_builder.add_iteration(LocalMetaStore::GetInstance().curr_iter_num()); + rsp_fl_job_builder.add_is_selected(is_selected); + rsp_fl_job_builder.add_next_req_time(fbs_next_req_time); + rsp_fl_job_builder.add_fl_plan_config(fbs_fl_plan); + rsp_fl_job_builder.add_feature_map(fbs_feature_maps_vector); + auto rsp_fl_job = rsp_fl_job_builder.Finish(); + fbb->Finish(rsp_fl_job); + return; +} + +REG_ROUND_KERNEL(startFLJob, StartFLJobKernel) +} // namespace kernel +} // namespace server +} // namespace ps +} // namespace mindspore diff --git a/mindspore/ccsrc/ps/server/kernel/round/start_fl_job_kernel.h b/mindspore/ccsrc/ps/server/kernel/round/start_fl_job_kernel.h new file mode 100644 index 00000000000..7e4b2ef47c2 --- /dev/null +++ b/mindspore/ccsrc/ps/server/kernel/round/start_fl_job_kernel.h @@ -0,0 +1,74 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_START_FL_JOB_KERNEL_H_ +#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_START_FL_JOB_KERNEL_H_ + +#include +#include +#include +#include +#include "ps/server/common.h" +#include "ps/server/executor.h" +#include "ps/server/kernel/round/round_kernel.h" +#include "ps/server/kernel/round/round_kernel_factory.h" + +namespace mindspore { +namespace ps { +namespace server { +namespace kernel { +class StartFLJobKernel : public RoundKernel { + public: + StartFLJobKernel() = default; + ~StartFLJobKernel() override = default; + + void InitKernel(size_t threshold_count) override; + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + bool Reset() override; + + private: + // Returns whether the startFLJob count of this iteration has reached the threshold. + bool ReachThresholdForStartFLJob(const std::shared_ptr &fbb); + + // The metadata of device will be stored and queried in updateModel round. + DeviceMeta CreateDeviceMetadata(const schema::RequestFLJob *start_fl_job_req); + + // Returns whether the request is valid for startFLJob.For now, the condition is simple. We will add more conditions + // to device in later versions. + bool ReadyForStartFLJob(const std::shared_ptr &fbb, const DeviceMeta &device_meta); + + // Distributed count service counts for startFLJob. + bool CountForStartFLJob(const std::shared_ptr &fbb, const schema::RequestFLJob *start_fl_job_req); + + void StartFLJob(const std::shared_ptr &fbb, const DeviceMeta &device_meta); + + // Build response for startFLJob round no matter success or failure. + void BuildStartFLJobRsp(const std::shared_ptr &fbb, const schema::ResponseCode retcode, + const std::string &reason, const bool is_selected, const std::string &next_req_time, + std::map feature_maps = {}); + + // The executor is for getting the initial model for startFLJob request. + Executor *executor_; + + // The time window of one iteration. + size_t iteration_time_window_; +}; +} // namespace kernel +} // namespace server +} // namespace ps +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_START_FL_JOB_KERNEL_H_ diff --git a/mindspore/ccsrc/ps/server/local_meta_storage.cc b/mindspore/ccsrc/ps/server/local_meta_store.cc similarity index 79% rename from mindspore/ccsrc/ps/server/local_meta_storage.cc rename to mindspore/ccsrc/ps/server/local_meta_store.cc index 72cc08d4f46..aab873d8989 100644 --- a/mindspore/ccsrc/ps/server/local_meta_storage.cc +++ b/mindspore/ccsrc/ps/server/local_meta_store.cc @@ -14,30 +14,29 @@ * limitations under the License. */ -#include "ps/server/local_meta_storage.h" -#include +#include "ps/server/local_meta_store.h" namespace mindspore { namespace ps { namespace server { -void LocalMetaStorage::remove_value(const std::string &name) { +void LocalMetaStore::remove_value(const std::string &name) { std::unique_lock lock(mtx_); if (key_to_meta_.count(name) != 0) { key_to_meta_.erase(key_to_meta_.find(name)); } } -bool LocalMetaStorage::has_value(const std::string &name) { +bool LocalMetaStore::has_value(const std::string &name) { std::unique_lock lock(mtx_); return key_to_meta_.count(name) != 0; } -void LocalMetaStorage::set_curr_iter_num(size_t num) { +void LocalMetaStore::set_curr_iter_num(size_t num) { std::unique_lock lock(mtx_); curr_iter_num_ = num; } -const size_t LocalMetaStorage::curr_iter_num() { +const size_t LocalMetaStore::curr_iter_num() { std::unique_lock lock(mtx_); return curr_iter_num_; } diff --git a/mindspore/ccsrc/ps/server/local_meta_storage.h b/mindspore/ccsrc/ps/server/local_meta_store.h similarity index 77% rename from mindspore/ccsrc/ps/server/local_meta_storage.h rename to mindspore/ccsrc/ps/server/local_meta_store.h index 3a7467e6589..5cbcc238a1b 100644 --- a/mindspore/ccsrc/ps/server/local_meta_storage.h +++ b/mindspore/ccsrc/ps/server/local_meta_store.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PS_SERVER_LOCAL_META_STORAGE_H_ -#define MINDSPORE_CCSRC_PS_SERVER_LOCAL_META_STORAGE_H_ +#ifndef MINDSPORE_CCSRC_PS_SERVER_LOCAL_META_STORE_H_ +#define MINDSPORE_CCSRC_PS_SERVER_LOCAL_META_STORE_H_ #include #include @@ -26,13 +26,13 @@ namespace mindspore { namespace ps { namespace server { -// LocalMetaStorage class is used for metadata storage of this server process. +// LocalMetaStore class is used for metadata storage of this server process. // For example, the current iteration number, time windows for round kernels, etc. -// LocalMetaStorage is threadsafe. -class LocalMetaStorage { +// LocalMetaStore is threadsafe. +class LocalMetaStore { public: - static LocalMetaStorage &GetInstance() { - static LocalMetaStorage instance; + static LocalMetaStore &GetInstance() { + static LocalMetaStore instance; return instance; } @@ -43,7 +43,7 @@ class LocalMetaStorage { } template - const T &value(const std::string &name) { + T value(const std::string &name) { std::unique_lock lock(mtx_); try { T value = std::any_cast(key_to_meta_[name]); @@ -71,10 +71,10 @@ class LocalMetaStorage { const size_t curr_iter_num(); private: - LocalMetaStorage() = default; - ~LocalMetaStorage() = default; - LocalMetaStorage(const LocalMetaStorage &) = delete; - LocalMetaStorage &operator=(const LocalMetaStorage &) = delete; + LocalMetaStore() = default; + ~LocalMetaStore() = default; + LocalMetaStore(const LocalMetaStore &) = delete; + LocalMetaStore &operator=(const LocalMetaStore &) = delete; // key_to_meta_ stores metadata with key-value format. std::unordered_map key_to_meta_; @@ -85,4 +85,4 @@ class LocalMetaStorage { } // namespace server } // namespace ps } // namespace mindspore -#endif // MINDSPORE_CCSRC_PS_SERVER_LOCAL_META_STORAGE_H_ +#endif // MINDSPORE_CCSRC_PS_SERVER_LOCAL_META_STORE_H_ diff --git a/mindspore/ccsrc/ps/server/model_store.cc b/mindspore/ccsrc/ps/server/model_store.cc new file mode 100644 index 00000000000..62c23f50ad4 --- /dev/null +++ b/mindspore/ccsrc/ps/server/model_store.cc @@ -0,0 +1,144 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ps/server/model_store.h" +#include +#include +#include +#include "ps/server/executor.h" + +namespace mindspore { +namespace ps { +namespace server { +void ModelStore::Init(uint32_t max_count) { + if (!Executor::GetInstance().initialized()) { + MS_LOG(EXCEPTION) << "Server's executor must be initialized before model storage."; + return; + } + + max_model_count_ = max_count; + iteration_to_model_[kInitIterationNum] = AssignNewModelMemory(); + model_size_ = ComputeModelSize(); +} + +bool ModelStore::StoreModelByIterNum(size_t iteration, const std::map &new_model) { + if (iteration_to_model_.count(iteration) != 0) { + MS_LOG(WARNING) << "Model for iteration " << iteration << " is already stored"; + return false; + } + if (new_model.empty()) { + MS_LOG(ERROR) << "Model feature map is empty."; + return false; + } + + std::shared_ptr memory_register; + if (iteration_to_model_.size() < max_model_count_) { + // If iteration_to_model_.size() is not max_model_count_, need to assign new memory for the model. + memory_register = AssignNewModelMemory(); + if (memory_register == nullptr) { + MS_LOG(ERROR) << "Memory for the new model is nullptr."; + return false; + } + + iteration_to_model_[iteration] = memory_register; + } else { + // If iteration_to_model_ size is already max_model_count_, we need to replace earliest model with the newest model. + memory_register = iteration_to_model_.begin()->second; + if (memory_register == nullptr) { + MS_LOG(ERROR) << "Earliest model is nullptr."; + return false; + } + iteration_to_model_.erase(iteration_to_model_.begin()); + } + + // Copy new model data to the the stored model. + auto &stored_model = memory_register->addresses(); + for (const auto &weight : new_model) { + const std::string &weight_name = weight.first; + if (stored_model.count(weight_name) != 0) { + MS_LOG(ERROR) << "The stored model has no weight " << weight_name; + continue; + } + + void *dst_addr = stored_model[weight_name]->addr; + size_t dst_size = stored_model[weight_name]->size; + void *src_addr = weight.second->addr; + size_t src_size = weight.second->size; + int ret = memcpy_s(dst_addr, dst_size, src_addr, src_size); + if (ret != 0) { + MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; + return false; + } + } + iteration_to_model_[iteration] = memory_register; + return true; +} + +std::map ModelStore::GetModelByIterNum(size_t iteration) { + std::map model = {}; + if (iteration_to_model_.count(iteration) == 0) { + MS_LOG(ERROR) << "Model for iteration " << iteration << " is not stored."; + return model; + } + model = iteration_to_model_[iteration]->addresses(); + return model; +} + +const std::map> &ModelStore::iteration_to_model() const { + return iteration_to_model_; +} + +size_t ModelStore::model_size() const { return model_size_; } + +std::shared_ptr ModelStore::AssignNewModelMemory() { + std::map model = Executor::GetInstance().GetModel(); + if (model.empty()) { + MS_LOG(EXCEPTION) << "Model feature map is empty."; + return nullptr; + } + + // Assign new memory for the model. + std::shared_ptr memory_register = std::make_shared(); + for (const auto &weight : model) { + const std::string weight_name = weight.first; + size_t weight_size = weight.second->size; + auto weight_data = std::make_unique(weight_size); + if (weight_data == nullptr) { + MS_LOG(EXCEPTION) << "Assign memory for weight failed."; + return nullptr; + } + + memory_register->RegisterArray(weight_name, &weight_data, weight_size); + } + return memory_register; +} + +size_t ModelStore::ComputeModelSize() { + if (iteration_to_model_.empty()) { + MS_LOG(EXCEPTION) << "Calculating model size failed: model for iteration 0 is not stored yet. "; + return 0; + } + + const auto &model = iteration_to_model_[kInitIterationNum]; + MS_EXCEPTION_IF_NULL(model); + size_t model_size = std::accumulate(model->addresses().begin(), model->addresses().end(), static_cast(0), + [](size_t s, const auto &weight) { return s + weight.second->size; }); + MS_LOG(INFO) << "Model size in byte is " << model_size; + return model_size; +} +} // namespace server +} // namespace ps +} // namespace mindspore diff --git a/mindspore/ccsrc/ps/server/model_store.h b/mindspore/ccsrc/ps/server/model_store.h new file mode 100644 index 00000000000..bbaf7ba295b --- /dev/null +++ b/mindspore/ccsrc/ps/server/model_store.h @@ -0,0 +1,78 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PS_SERVER_MODEL_STORE_H_ +#define MINDSPORE_CCSRC_PS_SERVER_MODEL_STORE_H_ + +#include +#include +#include +#include "ps/server/common.h" +#include "ps/server/memory_register.h" +#include "ps/server/executor.h" + +namespace mindspore { +namespace ps { +namespace server { +// The initial iteration number is 0 in server. +constexpr size_t kInitIterationNum = 0; + +// Server framework use ModelStore to store and query models. +// ModelStore stores multiple models because worker could get models of the previous iterations. +class ModelStore { + public: + static ModelStore &GetInstance() { + static ModelStore instance; + return instance; + } + + // Initialize ModelStore with max count of models need to be stored. + void Init(uint32_t max_count = 3); + + // Store the model of the given iteration. The model is acquired from Executor. If the current model count is already + // max_model_count_, the earliest model will be replaced. + bool StoreModelByIterNum(size_t iteration, const std::map &model); + + // Get model of the given iteration. + std::map GetModelByIterNum(size_t iteration); + + // Returns all models stored in ModelStore. + const std::map> &iteration_to_model() const; + + // Returns the model size, which could be calculated at the initializing phase. + size_t model_size() const; + + private: + ModelStore() : max_model_count_(0), model_size_(0), iteration_to_model_({}) {} + ~ModelStore() = default; + ModelStore(const ModelStore &) = delete; + ModelStore &operator=(const ModelStore &) = delete; + + // To store multiple models, new memory must assigned. The max memory size assigned for models is max_model_count_ * + // model_size_. + std::shared_ptr AssignNewModelMemory(); + + // Calculate the model size. This method should be called after iteration_to_model_ is initialized. + size_t ComputeModelSize(); + + size_t max_model_count_; + size_t model_size_; + std::map> iteration_to_model_; +}; +} // namespace server +} // namespace ps +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PS_SERVER_MODEL_STORE_H_ diff --git a/mindspore/ccsrc/ps/server/parameter_aggregator.cc b/mindspore/ccsrc/ps/server/parameter_aggregator.cc index fbe07634484..a7ffcfdc0fd 100644 --- a/mindspore/ccsrc/ps/server/parameter_aggregator.cc +++ b/mindspore/ccsrc/ps/server/parameter_aggregator.cc @@ -25,15 +25,15 @@ namespace mindspore { namespace ps { namespace server { -bool ParameterAggregator::Init(const CNodePtr &cnode, size_t required_count) { +bool ParameterAggregator::Init(const CNodePtr &cnode, size_t threshold_count) { MS_EXCEPTION_IF_NULL(cnode); memory_register_ = std::make_shared(); MS_EXCEPTION_IF_NULL(memory_register_); - required_push_count_ = required_count; + required_push_count_ = threshold_count; // The required_pull_count_ is the count for Pull, which should be the same as required_push_count_. // required_pull_count_ normally used in parameter server training mode. - required_pull_count_ = required_count; + required_pull_count_ = threshold_count; MS_LOG(DEBUG) << "Start initializing kernels for " << AnfAlgo::GetCNodeName(cnode); InitAggregationKernels(cnode); diff --git a/mindspore/ccsrc/ps/server/parameter_aggregator.h b/mindspore/ccsrc/ps/server/parameter_aggregator.h index 8344f75f123..8e51b3ae860 100644 --- a/mindspore/ccsrc/ps/server/parameter_aggregator.h +++ b/mindspore/ccsrc/ps/server/parameter_aggregator.h @@ -61,8 +61,8 @@ class ParameterAggregator { ~ParameterAggregator() = default; // Initialize ParameterAggregator with a cnode. This cnode is normally a optimizer kernel for now. - // The parameter required_count helps ParameterAggregator to judge the current status if it's stateful. - bool Init(const CNodePtr &cnode, size_t required_count = 0); + // The parameter threshold_count helps ParameterAggregator to judge the current status if it's stateful. + bool Init(const CNodePtr &cnode, size_t threshold_count = 0); // Update old data stored in ParameterAggregator with new data. // The data could have many meanings: weights, gradients, learning_rate, momentum, etc. diff --git a/mindspore/ccsrc/ps/server/round.cc b/mindspore/ccsrc/ps/server/round.cc new file mode 100644 index 00000000000..313a0024c52 --- /dev/null +++ b/mindspore/ccsrc/ps/server/round.cc @@ -0,0 +1,139 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ps/server/round.h" +#include +#include + +namespace mindspore { +namespace ps { +namespace server { +Round::Round(const std::string &name, bool check_timeout, size_t time_window, bool check_count, size_t threshold_count) + : name_(name), + check_timeout_(check_timeout), + time_window_(time_window), + check_count_(check_count), + threshold_count_(threshold_count) {} + +void Round::Initialize(const std::shared_ptr &communicator, TimeOutCb timeout_cb, + FinishIterCb finish_iteration_cb) { + MS_EXCEPTION_IF_NULL(communicator); + communicator_ = communicator; + + // Register callback for round kernel. + communicator_->RegisterMsgCallBack( + name_, [&](std::shared_ptr message) { LaunchRoundKernel(message); }); + + // Callback when the iteration is finished. + finish_iteration_cb_ = [this, finish_iteration_cb](void) -> void { + MS_LOG(INFO) << "Round " << name_ << " finished! Proceed to next iteration."; + finish_iteration_cb(); + }; + + // Callback for finalizing the server. This can only be called once. + finalize_cb_ = [&](void) -> void { communicator_->Stop(); }; + + if (check_timeout_) { + iter_timer_ = std::make_shared(); + + // 1.Set the timeout callback for the timer. + iter_timer_->SetTimeOutCallBack([this, timeout_cb](void) -> void { + MS_LOG(INFO) << "Round " << name_ << " timeout! Proceed to next iteration."; + timeout_cb(); + }); + + // 2.Stopping timer callback which will be set to the round kernel. + stop_timer_cb_ = [&](void) -> void { + MS_LOG(INFO) << "Round " << name_ << " kernel stops its timer."; + iter_timer_->Stop(); + }; + } + + // Set counter event callbacks for this round if the round kernel is stateful. + if (check_count_) { + auto first_count_handler = std::bind(&Round::OnFirstCountEvent, this, std::placeholders::_1); + auto last_count_handler = std::bind(&Round::OnLastCountEvent, this, std::placeholders::_1); + DistributedCountService::GetInstance().RegisterCounter(name_, threshold_count_, + {first_count_handler, last_count_handler}); + } +} + +void Round::BindRoundKernel(const std::shared_ptr &kernel) { + MS_EXCEPTION_IF_NULL(kernel); + kernel_ = kernel; + kernel_->set_stop_timer_cb(stop_timer_cb_); + kernel_->set_finish_iteration_cb(finish_iteration_cb_); + return; +} + +void Round::LaunchRoundKernel(const std::shared_ptr &message) { + if (message == nullptr) { + MS_LOG(ERROR) << "Message is nullptr."; + return; + } + + AddressPtr input = std::make_shared
(); + AddressPtr output = std::make_shared
(); + input->addr = message->data(); + input->size = message->len(); + bool ret = kernel_->Launch({input}, {}, {output}); + if (output->size == 0) { + std::string reason = "The output of the round " + name_ + " is empty."; + MS_LOG(WARNING) << reason; + communicator_->SendResponse(reason.c_str(), reason.size(), message); + return; + } + + // Must send response back no matter what value Launch method returns. + if (!ret) { + MS_LOG(WARNING) << "Launching round kernel of round " << name_ << " failed."; + } + communicator_->SendResponse(output->addr, output->size, message); + kernel_->Release(output); + return; +} + +void Round::Reset() { kernel_->Reset(); } + +const std::string &Round::name() const { return name_; } + +size_t Round::threshold_count() const { return threshold_count_; } + +size_t Round::time_window() const { return time_window_; } + +void Round::OnFirstCountEvent(const std::shared_ptr &) { + MS_LOG(INFO) << "Round " << name_ << " first count event is triggered."; + // The timer starts only after the first count event is triggered by DistributedCountService. + if (check_timeout_) { + iter_timer_->Start(std::chrono::milliseconds(time_window_)); + } + return; +} + +void Round::OnLastCountEvent(const std::shared_ptr &message) { + MS_LOG(INFO) << "Round " << name_ << " last count event is triggered."; + // Same as the first count event, the timer must be stopped by DistributedCountService. + if (check_timeout_) { + iter_timer_->Stop(); + } + + // Some kernels override the OnLastCountEvent method. + kernel_->OnLastCountEvent(message); + return; +} +} // namespace server +} // namespace ps +} // namespace mindspore diff --git a/mindspore/ccsrc/ps/server/round.h b/mindspore/ccsrc/ps/server/round.h new file mode 100644 index 00000000000..e927e1904fe --- /dev/null +++ b/mindspore/ccsrc/ps/server/round.h @@ -0,0 +1,95 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PS_SERVER_ROUND_H_ +#define MINDSPORE_CCSRC_PS_SERVER_ROUND_H_ + +#include +#include +#include "ps/core/communicator/communicator_base.h" +#include "ps/server/common.h" +#include "ps/server/iteration_timer.h" +#include "ps/server/distributed_count_service.h" +#include "ps/server/kernel/round/round_kernel.h" + +namespace mindspore { +namespace ps { +namespace server { +// Round helps server to handle network round messages and launch round kernels. One iteration in server consists of +// multiple rounds like startFLJob, updateModel, Push, Pull, etc. Some round kernels may be stateful because of counting +// and timing. So Round helps register counter and timer so that the round kernels only need to focus on the logic. +class Round { + public: + explicit Round(const std::string &name, bool check_timeout = true, size_t time_window = 3000, + bool check_count = false, size_t threshold_count = 8); + ~Round() = default; + + void Initialize(const std::shared_ptr &communicator, TimeOutCb timeout_cb, + FinishIterCb finish_iteration_cb); + + // Bind a round kernel to this Round. This method should be called after Initialize. + void BindRoundKernel(const std::shared_ptr &kernel); + + // This method is the callback which will be set to the communicator and called after the corresponding round message + // is sent to the server. + void LaunchRoundKernel(const std::shared_ptr &message); + + // Round needs to be reset after each iteration is finished or its timer expires. + void Reset(); + + const std::string &name() const; + size_t threshold_count() const; + size_t time_window() const; + + private: + // The callbacks which will be set to DistributedCounterService. + void OnFirstCountEvent(const std::shared_ptr &message); + void OnLastCountEvent(const std::shared_ptr &message); + + std::string name_; + + // Whether this round needs to use timer. Most rounds in federated learning with mobile devices scenario need to set + // check_timeout_ to true. + bool check_timeout_; + + // The time window duration for this round when check_timeout_ is set to true. + size_t time_window_; + + // If check_count_ is true, it means the round has to do counting for every round message and the first/last count + // event will be triggered. + bool check_count_; + + // The threshold count for this round when check_count_ is set to true. The logic of this round has to check whether + // the round message count has reached threshold_count_. + size_t threshold_count_; + + std::shared_ptr communicator_; + + // The round kernel for this Round. + std::shared_ptr kernel_; + + // Some rounds may need timer to eliminate the long tail effect. + std::shared_ptr iter_timer_; + + // The callbacks which will be set to the round kernel. + StopTimerCb stop_timer_cb_; + FinishIterCb finish_iteration_cb_; + FinalizeCb finalize_cb_; +}; +} // namespace server +} // namespace ps +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PS_SERVER_ROUND_H_ diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime.cc b/mindspore/ccsrc/runtime/device/kernel_runtime.cc index ae6c92541f7..d7aed3f39ba 100644 --- a/mindspore/ccsrc/runtime/device/kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/kernel_runtime.cc @@ -31,7 +31,7 @@ #include "utils/utils.h" #include "frontend/parallel/context.h" #include "debug/env_config_parser.h" -#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) +#if (ENABLE_CPU && !_WIN32) #include "ps/ps_cache/ps_cache_manager.h" #endif @@ -307,7 +307,7 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) { } need_alloc_nodes.push_back(item); } -#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) +#if (ENABLE_CPU && !_WIN32) bool ps_cache_check = false; #endif for (auto &item : need_alloc_nodes) { @@ -320,7 +320,7 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) { continue; } DeviceAddressPtr device_address = nullptr; -#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) +#if (ENABLE_CPU && !_WIN32) const std::string ¶m_name = item->fullname_with_scope(); if (ps::ps_cache_instance.IsHashTable(param_name)) { MS_LOG(INFO) << "Parameter(" << param_name << ")" @@ -1038,7 +1038,7 @@ DeviceAddressPtr KernelRuntime::AssignSingleOpLaunchMemory(size_t size, const st return device_address; } -#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) +#if (ENABLE_CPU && !_WIN32) void KernelRuntime::GetFirstPSEmbeddingCache(const session::KernelGraph *graph, AnfNodePtr *const first_cache_input_index, size_t *const first_cache_size) { diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime.h b/mindspore/ccsrc/runtime/device/kernel_runtime.h index 802518892d0..75d2794cad9 100644 --- a/mindspore/ccsrc/runtime/device/kernel_runtime.h +++ b/mindspore/ccsrc/runtime/device/kernel_runtime.h @@ -142,7 +142,7 @@ class KernelRuntime { void RunOpAssignOutputNodeMemory(const ValuePtr &pre_output_value, session::KernelGraph *graph); void AssignValueNodeTensor(const ValueNodePtr &value_node, const ValuePtr &node_value, size_t output_idx); DeviceAddressPtr PreAssignCNodeMemory(const AnfNodePtr &anf_node, size_t index); -#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) +#if (ENABLE_CPU && !_WIN32) void GetFirstPSEmbeddingCache(const session::KernelGraph *graph, AnfNodePtr *const first_cache_input_index, size_t *const first_cache_size); void CheckIfSupportPSEmbeddingCache(const session::KernelGraph *graph); diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime_manager.cc b/mindspore/ccsrc/runtime/device/kernel_runtime_manager.cc index 90b43369c97..f347e1936b7 100644 --- a/mindspore/ccsrc/runtime/device/kernel_runtime_manager.cc +++ b/mindspore/ccsrc/runtime/device/kernel_runtime_manager.cc @@ -16,14 +16,14 @@ #include "runtime/device/kernel_runtime_manager.h" #include "utils/log_adapter.h" -#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) +#if (ENABLE_CPU && !_WIN32) #include "ps/ps_cache/ps_cache_manager.h" #endif namespace mindspore { namespace device { void KernelRuntimeManager::ClearRuntimeResource() { -#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) +#if (ENABLE_CPU && !_WIN32) if (ps::PSContext::instance()->is_worker() && ps::PsDataPrefetch::GetInstance().cache_enable()) { ps::ps_cache_instance.SyncEmbeddingTable(); } @@ -125,7 +125,7 @@ void KernelRuntimeManager::ReleaseKernelRuntime(const std::string &device_name, if (runtime == nullptr) { return; } -#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) +#if (ENABLE_CPU && !_WIN32) if (ps::PSContext::instance()->is_worker() && ps::PsDataPrefetch::GetInstance().cache_enable()) { ps::ps_cache_instance.SyncEmbeddingTable(); } diff --git a/mindspore/schema/cipher.fbs b/mindspore/schema/cipher.fbs new file mode 100644 index 00000000000..bd27cfb7aa2 --- /dev/null +++ b/mindspore/schema/cipher.fbs @@ -0,0 +1,123 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +namespace mindspore.schema; + +table CipherPublicParams { + t:int; + p:[ubyte]; + g:int; + prime:[ubyte]; + dp_eps:float; + dp_delta:float; + dp_norm_clip:float; + encrypt_type:int; +} + +table ClientPublicKeys { + fl_id:string; + c_pk:[ubyte]; + s_pk: [ubyte]; +} + +table ClientShare { + fl_id:string; + share:[ubyte]; + index:int; +} + +table RequestExchangeKeys{ + fl_id:string; + c_pk:[ubyte]; + s_pk:[ubyte]; + iteration:int; + timestamp:string; +} + +table ResponseExchangeKeys{ + retcode:int; + reason:string; + next_req_time:string; + iteration:int; +} + +table GetExchangeKeys{ + fl_id:string; + iteration:int; + timestamp:string; +} + +table ReturnExchangeKeys{ + retcode:int; + iteration:int; + remote_publickeys:[ClientPublicKeys]; + next_req_time:string; +} + +table RequestShareSecrets{ + fl_id:string; + encrypted_shares:[ClientShare]; + iteration:int; + timestamp:string; +} + +table ResponseShareSecrets{ + retcode:int; + reason:string; + next_req_time:string; + iteration:int; +} + +table GetShareSecrets{ + fl_id:string; + iteration:int; + timestamp:string; +} + +table ReturnShareSecrets{ + retcode:int; + iteration:int; + encrypted_shares: [ClientShare]; + next_req_time:string; +} + +table GetClientList{ + fl_id:string; + iteration:int; + timestamp:string; +} + +table ReturnClientList{ + retcode:int; + reason:string; + clients:[string]; + iteration:int; + next_req_time:string; +} + +table SendReconstructSecret{ + fl_id:string; + reconstruct_secret_shares:[ClientShare]; + iteration:int; + timestamp:string; +} + +table ReconstructSecret{ + retcode:int; + reason:string; + iteration:int; + next_req_time:string; +} diff --git a/mindspore/schema/fl_job.fbs b/mindspore/schema/fl_job.fbs new file mode 100644 index 00000000000..4c9839b6158 --- /dev/null +++ b/mindspore/schema/fl_job.fbs @@ -0,0 +1,159 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +include "cipher.fbs"; +namespace mindspore.schema; + +file_identifier "FLJ0"; +file_extension "fl"; + +enum ResponseCode: int { + SUCCEED=200, + SucNotReady=201, + RepeatRequest=202, + SucNotMatch=204, + OutOfTime=300, + NotSelected=301, + RequestError=400, + SystemError=500 +} + +enum AggregationType:byte {FedAvg=0, FedAdam = 1, FedAdagrag=2, FedMeta=3, qffl=4} +enum Metrics:byte {accuracy = 0, precision = 1, recall = 2, AUC = 3,f1=4, fbeta=5} +enum EarlyStopType:byte {loss_diff = 0, loss_abs = 1, weight_diff = 2} + +table Aggregation { + type:AggregationType; + weights:[float]; +} + +table EarlyStop { + early_stop_type:EarlyStopType; + weight:float; + rounds:int; + } + +table FeatureMap{ + weight_fullname:string; + data:[float]; +} +table RequestFLJob{ + fl_name:string; + fl_id:string; + iteration:int; + data_size:int; + timestamp:string; +} +table ResponseFLJob { + retcode:int; + reason:string; + iteration:int; + is_selected:bool = false; + next_req_time:string; + fl_plan_config:FLPlan; + feature_map:[FeatureMap]; + timestamp:string; +} + +table FLPlan { + fl_name:string; + iterations:int; + epochs:int; + early_stop:EarlyStop; + mini_batch:int; + shuffle:bool = false; + lr:float; + aggregation:Aggregation; + metrics:[Metrics]; + cipher:CipherPublicParams; +} + +table RequestUpdateModel{ + fl_name:string; + fl_id:string; + iteration:int; + feature_map:[FeatureMap]; + timestamp:string; +} +table ResponseUpdateModel{ + retcode:int; + reason:string; + feature_map:[FeatureMap]; + next_req_time:string; + timestamp:string; +} + +table RequestAsyncUpdateModel{ + fl_name:string; + fl_id:string; + iteration:int; + data_size:int; + feature_map:[FeatureMap]; +} +table ResponseAsyncUpdateModel{ + retcode:int; + reason:string; + iteration:int; +} + +table RequestOverwriteWeightsByKey{ + iteration:int; + feature_map:[FeatureMap]; +} + +table ResponseOverwriteWeightsByKey{ + retcode:int; + reason:string; +} + +table RequestGetModel{ + fl_name:string; + iteration:int; + timestamp:string; +} +table ResponseGetModel{ + retcode:int; + reason:string; + iteration:int; + feature_map:[FeatureMap]; + timestamp:string; +} + +table RequestAsyncGetModel{ + fl_name:string; + iteration:int; +} +table ResponseAsyncGetModel{ + retcode:int; + reason:string; + iteration:int; + feature_map:[FeatureMap]; +} + +table RequestGetWeightsByKey{ + iteration:int; + weight_names:[string]; +} +table ResponseGetWeightsByKey{ + retcode:int; + reason:string; + feature_map:[FeatureMap]; +} + +// FeatureMapList refers to the whole trained model. +table FeatureMapList { + feature_map:[FeatureMap]; +}