Add server code part2

This commit is contained in:
ZPaC 2021-04-26 14:36:19 +08:00
parent cb6e055736
commit 12f95b51f4
57 changed files with 2846 additions and 107 deletions

View File

@ -35,7 +35,7 @@ function(ms_build_flatbuffers source_schema_files
set(total_schema_dirs -I ${schema_dir} ${total_schema_dirs}) set(total_schema_dirs -I ${schema_dir} ${total_schema_dirs})
endforeach() endforeach()
foreach(schema ${source_schema_files}) foreach(schema IN LISTS ${source_schema_files})
get_filename_component(filename ${schema} NAME_WE) get_filename_component(filename ${schema} NAME_WE)
if(NOT ${generated_output_dir} STREQUAL "") if(NOT ${generated_output_dir} STREQUAL "")
set(generated_file ${generated_output_dir}/${filename}_generated.h) set(generated_file ${generated_output_dir}/${filename}_generated.h)

View File

@ -212,7 +212,7 @@ if(ENABLE_GPU)
) )
endif() endif()
if(ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)) if(ENABLE_CPU AND NOT WIN32)
install( install(
TARGETS ps_cache TARGETS ps_cache
DESTINATION ${INSTALL_LIB_DIR} DESTINATION ${INSTALL_LIB_DIR}

View File

@ -373,7 +373,7 @@ elseif(CMAKE_SYSTEM_NAME MATCHES "Darwin")
target_link_libraries(mindspore mindspore_gvar) target_link_libraries(mindspore mindspore_gvar)
target_link_libraries(_c_expression PRIVATE -Wl,-force_load mindspore mindspore_core -Wl,-noall_load) target_link_libraries(_c_expression PRIVATE -Wl,-force_load mindspore mindspore_core -Wl,-noall_load)
else() else()
if(ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)) if(ENABLE_CPU AND NOT WIN32)
target_link_libraries(mindspore proto_input mindspore::protobuf target_link_libraries(mindspore proto_input mindspore::protobuf
mindspore::event mindspore::event_pthreads mindspore::event_openssl mindspore::json) mindspore::event mindspore::event_pthreads mindspore::event_openssl mindspore::json)
target_link_libraries(mindspore -Wl,--no-as-needed mindspore::event_core ps_cache) target_link_libraries(mindspore -Wl,--no-as-needed mindspore::event_core ps_cache)

View File

@ -75,7 +75,7 @@ if(ENABLE_CPU)
endif() endif()
endif() endif()
if(NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))) if(NOT ENABLE_CPU OR WIN32)
list(REMOVE_ITEM CPU_SRC_LIST "cpu/ps/apply_momentum_ps_kernel.cc") list(REMOVE_ITEM CPU_SRC_LIST "cpu/ps/apply_momentum_ps_kernel.cc")
list(REMOVE_ITEM CPU_SRC_LIST "cpu/ps/embedding_look_up_proxy_kernel.cc") list(REMOVE_ITEM CPU_SRC_LIST "cpu/ps/embedding_look_up_proxy_kernel.cc")
list(REMOVE_ITEM CPU_SRC_LIST "cpu/ps/embedding_look_up_ps_kernel.cc") list(REMOVE_ITEM CPU_SRC_LIST "cpu/ps/embedding_look_up_ps_kernel.cc")

View File

@ -421,7 +421,7 @@ void AscendSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_gra
size = abstract::ShapeSize(shape_tmp) * abstract::TypeIdSize(tensor->data_type()); size = abstract::ShapeSize(shape_tmp) * abstract::TypeIdSize(tensor->data_type());
} }
if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0) && TensorNeedSync(input_node, tensor)) { if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0) && TensorNeedSync(input_node, tensor)) {
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) #if (ENABLE_CPU && !_WIN32)
const std::string &param_name = input_node->fullname_with_scope(); const std::string &param_name = input_node->fullname_with_scope();
if (ps::ps_cache_instance.IsHashTable(param_name)) { if (ps::ps_cache_instance.IsHashTable(param_name)) {
continue; continue;

View File

@ -33,7 +33,7 @@
#include "debug/anf_ir_dump.h" #include "debug/anf_ir_dump.h"
#include "debug/dump_proto.h" #include "debug/dump_proto.h"
#include "debug/data_dump/dump_json_parser.h" #include "debug/data_dump/dump_json_parser.h"
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) #if (ENABLE_CPU && !_WIN32)
#include "ps/util.h" #include "ps/util.h"
#include "ps/ps_context.h" #include "ps/ps_context.h"
#endif #endif
@ -74,7 +74,7 @@ void CPUSession::Reorder(std::vector<CNodePtr> *node_list) { AnfAlgo::ReorderPos
void CPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) { void CPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
auto optimizer = std::make_shared<opt::GraphOptimizer>(); auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>(); auto pm = std::make_shared<opt::PassManager>();
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) #if (ENABLE_CPU && !_WIN32)
auto ms_context = MsContext::GetInstance(); auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context); MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode && ps::PSContext::instance()->is_ps_mode()) { if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode && ps::PSContext::instance()->is_ps_mode()) {
@ -174,7 +174,7 @@ void CPUSession::PreExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_grap
MS_LOG(INFO) << "Bind input output address"; MS_LOG(INFO) << "Bind input output address";
runtime_.BindInputOutput(kernel_graph.get(), inputs, outputs); runtime_.BindInputOutput(kernel_graph.get(), inputs, outputs);
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) #if (ENABLE_CPU && !_WIN32)
InitPSParamAndOptim(kernel_graph, inputs); InitPSParamAndOptim(kernel_graph, inputs);
#endif #endif
} }

View File

@ -21,7 +21,7 @@
#include "utils/comm_manager.h" #include "utils/comm_manager.h"
#include "utils/scoped_long_running.h" #include "utils/scoped_long_running.h"
#include "pybind_api/ir/tensor_py.h" #include "pybind_api/ir/tensor_py.h"
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) #if (ENABLE_CPU && !_WIN32)
#include "ps/ps_cache/ps_cache_manager.h" #include "ps/ps_cache/ps_cache_manager.h"
#endif #endif

View File

@ -43,7 +43,7 @@
#include "debug/common.h" #include "debug/common.h"
#include "utils/trace_base.h" #include "utils/trace_base.h"
#include "frontend/parallel/context.h" #include "frontend/parallel/context.h"
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) #if (ENABLE_CPU && !_WIN32)
#include "ps/ps_cache/ps_cache_manager.h" #include "ps/ps_cache/ps_cache_manager.h"
#include "ps/constants.h" #include "ps/constants.h"
#include "ps/util.h" #include "ps/util.h"
@ -2357,7 +2357,7 @@ void SessionBasic::DumpGraph(const std::shared_ptr<KernelGraph> &kernel_graph) {
#endif #endif
} }
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) #if (ENABLE_CPU && !_WIN32)
void SessionBasic::InitPsWorker(const KernelGraphPtr &kernel_graph) { void SessionBasic::InitPsWorker(const KernelGraphPtr &kernel_graph) {
if (!ps::PSContext::instance()->is_worker()) { if (!ps::PSContext::instance()->is_worker()) {
return; return;

View File

@ -244,7 +244,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
std::vector<uint32_t> GetAllReduceSplitIndex(); std::vector<uint32_t> GetAllReduceSplitIndex();
virtual std::string GetCommWorldGroup() { return std::string(); } virtual std::string GetCommWorldGroup() { return std::string(); }
void DumpGraph(const std::shared_ptr<KernelGraph> &kernel_graph); void DumpGraph(const std::shared_ptr<KernelGraph> &kernel_graph);
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) #if (ENABLE_CPU && !_WIN32)
void CheckPSModeConsistence(const KernelGraphPtr &kernel_graph) const; void CheckPSModeConsistence(const KernelGraphPtr &kernel_graph) const;
void GetBatchElements(const AnfNodePtr &kernel_node) const; void GetBatchElements(const AnfNodePtr &kernel_node) const;
void InitPsWorker(const KernelGraphPtr &kernel_graph); void InitPsWorker(const KernelGraphPtr &kernel_graph);
@ -263,7 +263,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
#if !defined(_WIN32) && !defined(_WIN64) #if !defined(_WIN32) && !defined(_WIN64)
std::shared_ptr<Debugger> debugger_; std::shared_ptr<Debugger> debugger_;
#endif #endif
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) #if (ENABLE_CPU && !_WIN32)
bool initialized_ps_cache_{false}; bool initialized_ps_cache_{false};
#endif #endif
}; };

View File

@ -25,7 +25,7 @@
#include "frontend/parallel/device_matrix.h" #include "frontend/parallel/device_matrix.h"
#include "frontend/parallel/graph_util/generate_graph.h" #include "frontend/parallel/graph_util/generate_graph.h"
#include "frontend/parallel/context.h" #include "frontend/parallel/context.h"
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) #if (ENABLE_CPU && !_WIN32)
#include "ps/ps_cache/ps_cache_manager.h" #include "ps/ps_cache/ps_cache_manager.h"
#include "utils/ms_context.h" #include "utils/ms_context.h"
#endif #endif
@ -160,7 +160,7 @@ Status GatherPInfo::GetAttrs() {
if (std::find(inputs_shape_[1].begin(), inputs_shape_[1].end(), -1) != inputs_shape_[1].end()) { if (std::find(inputs_shape_[1].begin(), inputs_shape_[1].end(), -1) != inputs_shape_[1].end()) {
dynamic_shape_indices_ = true; dynamic_shape_indices_ = true;
} }
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) #if (ENABLE_CPU && !_WIN32)
MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode(); std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode();
MS_EXCEPTION_IF_NULL(MsContext::GetInstance()); MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
@ -617,7 +617,7 @@ Status GatherPInfo::InferBias() {
rank = rank % (params_strategy[0] * params_strategy[1]); rank = rank % (params_strategy[0] * params_strategy[1]);
} }
} }
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) #if (ENABLE_CPU && !_WIN32)
if (ps::PsDataPrefetch::GetInstance().cache_enable()) { if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
bias_ = static_cast<int64_t>(ps::PsCacheManager::GetInstance().cache_indices_lower_bound()); bias_ = static_cast<int64_t>(ps::PsCacheManager::GetInstance().cache_indices_lower_bound());
return SUCCESS; return SUCCESS;

View File

@ -28,7 +28,7 @@
#include "frontend/parallel/strategy.h" #include "frontend/parallel/strategy.h"
#include "frontend/parallel/context.h" #include "frontend/parallel/context.h"
#include "frontend/parallel/tensor_layout/tensor_redistribution.h" #include "frontend/parallel/tensor_layout/tensor_redistribution.h"
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) #if (ENABLE_CPU && !_WIN32)
#include "ps/ps_cache/ps_cache_manager.h" #include "ps/ps_cache/ps_cache_manager.h"
#endif #endif
@ -192,7 +192,7 @@ Status UniqueInfo::GenerateStrategies(int64_t stage_id) {
return SUCCESS; return SUCCESS;
} }
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) #if (ENABLE_CPU && !_WIN32)
Status UniqueInfo::ComputeReplaceGraph(const CNodePtr &cnode) { Status UniqueInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
GenerateGraph gen_g = GenerateGraph(); GenerateGraph gen_g = GenerateGraph();
if (gen_g.Init(cnode) != SUCCESS) { if (gen_g.Init(cnode) != SUCCESS) {
@ -230,7 +230,7 @@ Status UniqueInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
#endif #endif
ReplaceGraphPtr UniqueInfo::replace_graph(const CNodePtr &cnode) { ReplaceGraphPtr UniqueInfo::replace_graph(const CNodePtr &cnode) {
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) #if (ENABLE_CPU && !_WIN32)
if (ps::PsDataPrefetch::GetInstance().cache_enable()) { if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
auto inputs = cnode->inputs(); auto inputs = cnode->inputs();
if (inputs.empty()) { if (inputs.empty()) {

View File

@ -51,7 +51,7 @@ class UniqueInfo : public OperatorInfo {
Status InferMirrorOps() override; Status InferMirrorOps() override;
Status InferForwardCommunication() override { return SUCCESS; } Status InferForwardCommunication() override { return SUCCESS; }
Status InferAsLossDivisor() override { return SUCCESS; } Status InferAsLossDivisor() override { return SUCCESS; }
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) #if (ENABLE_CPU && !_WIN32)
Status ComputeReplaceGraph(const CNodePtr &cnode); Status ComputeReplaceGraph(const CNodePtr &cnode);
#endif #endif

View File

@ -47,14 +47,14 @@
#include "ir/anf.h" #include "ir/anf.h"
#include "ir/param_info.h" #include "ir/param_info.h"
#include "ir/tensor.h" #include "ir/tensor.h"
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) #if (ENABLE_CPU && !_WIN32)
#include "ps/util.h" #include "ps/util.h"
#endif #endif
namespace mindspore { namespace mindspore {
namespace parallel { namespace parallel {
bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) { bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) {
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) #if (ENABLE_CPU && !_WIN32)
if (ps::Util::IsRoleOfPServer() || ps::Util::IsRoleOfScheduler()) { if (ps::Util::IsRoleOfPServer() || ps::Util::IsRoleOfScheduler()) {
return false; return false;
} }

View File

@ -46,7 +46,7 @@
#include "utils/ms_context.h" #include "utils/ms_context.h"
#include "utils/symbolic.h" #include "utils/symbolic.h"
#include "mindspore/core/utils/parallel_node_check.h" #include "mindspore/core/utils/parallel_node_check.h"
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) #if (ENABLE_CPU && !_WIN32)
#include "ps/util.h" #include "ps/util.h"
#include "ps/ps_context.h" #include "ps/ps_context.h"
#endif #endif
@ -3553,7 +3553,7 @@ static void HandleFullySplitParameters(const FuncGraphPtr &root) {
} }
bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) { bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) {
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) #if (ENABLE_CPU && !_WIN32)
if (ps::PSContext::instance()->is_server() || ps::PSContext::instance()->is_scheduler()) { if (ps::PSContext::instance()->is_server() || ps::PSContext::instance()->is_scheduler()) {
return false; return false;
} }

View File

@ -295,7 +295,7 @@ if(${CMAKE_SYSTEM_NAME} MATCHES "Windows")
target_link_libraries(_c_dataengine PRIVATE _c_mindrecord ${MINDRECORD_LINK_OBJECT} mindspore::sqlite) target_link_libraries(_c_dataengine PRIVATE _c_mindrecord ${MINDRECORD_LINK_OBJECT} mindspore::sqlite)
else() else()
target_link_libraries(_c_dataengine PRIVATE _c_mindrecord) target_link_libraries(_c_dataengine PRIVATE _c_mindrecord)
if(ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)) if(ENABLE_CPU AND NOT WIN32)
if(${ENABLE_IBVERBS} STREQUAL "ON") if(${ENABLE_IBVERBS} STREQUAL "ON")
target_link_libraries(_c_dataengine PRIVATE ibverbs rdmacm) target_link_libraries(_c_dataengine PRIVATE ibverbs rdmacm)
endif() endif()

View File

@ -1,7 +1,8 @@
add_subdirectory(perf EXCLUDE_FROM_ALL) add_subdirectory(perf EXCLUDE_FROM_ALL)
include_directories("${CMAKE_BINARY_DIR}/minddata/dataset/engine/cache") include_directories("${CMAKE_BINARY_DIR}/minddata/dataset/engine/cache")
set(MD_FLATBUFFER_OU "${CMAKE_BINARY_DIR}/minddata/dataset/engine/cache") set(MD_FLATBUFFER_OU "${CMAKE_BINARY_DIR}/minddata/dataset/engine/cache")
ms_build_flatbuffers("de_tensor.fbs" ${CMAKE_CURRENT_SOURCE_DIR} generated_engine_files ${MD_FLATBUFFER_OU}) set(FBS_FILES de_tensor.fbs)
ms_build_flatbuffers(FBS_FILES ${CMAKE_CURRENT_SOURCE_DIR} generated_engine_files ${MD_FLATBUFFER_OU})
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)

View File

@ -43,7 +43,7 @@
#include "vm/transform.h" #include "vm/transform.h"
#include "parse/python_adapter.h" #include "parse/python_adapter.h"
#include "frontend/optimizer/py_pass_manager.h" #include "frontend/optimizer/py_pass_manager.h"
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) #if (ENABLE_CPU && !_WIN32)
#include "ps/parameter_server.h" #include "ps/parameter_server.h"
#include "ps/scheduler.h" #include "ps/scheduler.h"
#include "ps/worker.h" #include "ps/worker.h"
@ -606,7 +606,7 @@ bool ExecuteAction(const ResourcePtr &res) {
return true; return true;
} }
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) #if (ENABLE_CPU && !_WIN32)
bool StartPSWorkerAction(const ResourcePtr &res) { bool StartPSWorkerAction(const ResourcePtr &res) {
ps::Worker::GetInstance().Run(); ps::Worker::GetInstance().Run();
return true; return true;
@ -782,7 +782,7 @@ std::vector<ActionItem> VmPipeline() {
actions.emplace_back(std::make_pair("auto_monad_reorder", OrderEnforceAction)); actions.emplace_back(std::make_pair("auto_monad_reorder", OrderEnforceAction));
actions.emplace_back(std::make_pair("validate", ValidateAction)); actions.emplace_back(std::make_pair("validate", ValidateAction));
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) #if (ENABLE_CPU && !_WIN32)
if (ps::PSContext::instance()->is_worker()) { if (ps::PSContext::instance()->is_worker()) {
actions.emplace_back(std::make_pair("worker", StartPSWorkerAction)); actions.emplace_back(std::make_pair("worker", StartPSWorkerAction));
} }
@ -796,7 +796,7 @@ std::vector<ActionItem> VmPipeline() {
return actions; return actions;
} }
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) #if (ENABLE_CPU && !_WIN32)
std::vector<ActionItem> PServerPipeline() { std::vector<ActionItem> PServerPipeline() {
auto actions = CommonPipeline(); auto actions = CommonPipeline();
actions.emplace_back(std::make_pair("optimize", VmOptimizeAction)); actions.emplace_back(std::make_pair("optimize", VmOptimizeAction));

View File

@ -34,7 +34,7 @@
#else #else
#include "runtime/device/gpu/distribution/collective_fake_init.h" #include "runtime/device/gpu/distribution/collective_fake_init.h"
#endif #endif
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) #if (ENABLE_CPU && !_WIN32)
#include "ps/util.h" #include "ps/util.h"
#endif #endif
#include "ps/ps_context.h" #include "ps/ps_context.h"

View File

@ -42,7 +42,7 @@
#include "pipeline/jit/pipeline_split.h" #include "pipeline/jit/pipeline_split.h"
#include "pipeline/jit/static_analysis/auto_monad.h" #include "pipeline/jit/static_analysis/auto_monad.h"
#include "frontend/optimizer/irpass/gradient_eliminate.h" #include "frontend/optimizer/irpass/gradient_eliminate.h"
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) #if (ENABLE_CPU && !_WIN32)
#include "ps/util.h" #include "ps/util.h"
#include "ps/ps_context.h" #include "ps/ps_context.h"
#endif #endif
@ -407,7 +407,7 @@ bool AddRecomputationPass(const ResourcePtr &res) {
} }
bool AddCacheEmbeddingPass(const ResourcePtr &res) { bool AddCacheEmbeddingPass(const ResourcePtr &res) {
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) #if (ENABLE_CPU && !_WIN32)
if (ps::PSContext::instance()->is_ps_mode()) { if (ps::PSContext::instance()->is_ps_mode()) {
return true; return true;
} }

View File

@ -49,7 +49,7 @@
#include "utils/shape_utils.h" #include "utils/shape_utils.h"
#include "utils/info.h" #include "utils/info.h"
#include "load_mindir/load_model.h" #include "load_mindir/load_model.h"
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) #if (ENABLE_CPU && !_WIN32)
#include "ps/constants.h" #include "ps/constants.h"
#include "ps/util.h" #include "ps/util.h"
#include "ps/worker.h" #include "ps/worker.h"
@ -528,7 +528,7 @@ std::vector<ActionItem> GetPipeline(const ResourcePtr &resource, const std::stri
std::string backend = MsContext::GetInstance()->backend_policy(); std::string backend = MsContext::GetInstance()->backend_policy();
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) #if (ENABLE_CPU && !_WIN32)
if (ps::PSContext::instance()->is_server()) { if (ps::PSContext::instance()->is_server()) {
resource->results()[kBackend] = compile::CreateBackend(); resource->results()[kBackend] = compile::CreateBackend();
return PServerPipeline(); return PServerPipeline();
@ -961,7 +961,7 @@ bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t ba
bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batch_size, bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batch_size,
const std::vector<TypePtr> &types, const std::vector<std::vector<int64_t>> &shapes, const std::vector<TypePtr> &types, const std::vector<std::vector<int64_t>> &shapes,
const std::vector<int64_t> &input_indexes, bool need_run) { const std::vector<int64_t> &input_indexes, bool need_run) {
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) #if (ENABLE_CPU && !_WIN32)
if ((ps::PSContext::instance()->is_ps_mode()) && (!ps::PSContext::instance()->is_worker())) { if ((ps::PSContext::instance()->is_ps_mode()) && (!ps::PSContext::instance()->is_worker())) {
return true; return true;
} }
@ -1027,7 +1027,7 @@ bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batc
} }
ConfigManager::GetInstance().set_iter_num(size); ConfigManager::GetInstance().set_iter_num(size);
// PS cache does not support loop sink. // PS cache does not support loop sink.
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) #if (ENABLE_CPU && !_WIN32)
if (ps::PSContext::instance()->is_worker() && ps::PsDataPrefetch::GetInstance().cache_enable()) { if (ps::PSContext::instance()->is_worker() && ps::PsDataPrefetch::GetInstance().cache_enable()) {
ps::PsDataPrefetch::GetInstance().CreateDataChannel(queue_name, LongToSize(size)); ps::PsDataPrefetch::GetInstance().CreateDataChannel(queue_name, LongToSize(size));
ConfigManager::GetInstance().set_iter_num(1); ConfigManager::GetInstance().set_iter_num(1);
@ -1150,7 +1150,7 @@ void FinalizeBackend() {
void ClearResAtexit() { void ClearResAtexit() {
MS_LOG(DEBUG) << "Pipeline clear all resource"; MS_LOG(DEBUG) << "Pipeline clear all resource";
pynative::ClearPyNativeSession(); pynative::ClearPyNativeSession();
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) #if (ENABLE_CPU && !_WIN32)
if (ps::PSContext::instance()->is_ps_mode() && ps::PSContext::instance()->is_worker()) { if (ps::PSContext::instance()->is_ps_mode() && ps::PSContext::instance()->is_worker()) {
if (ps::PsDataPrefetch::GetInstance().cache_enable()) { if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
ps::ps_cache_instance.Finalize(); ps::ps_cache_instance.Finalize();

View File

@ -1,6 +1,13 @@
file(GLOB_RECURSE _PS_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") file(GLOB_RECURSE _PS_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
if(NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))) set(SERVER_FLATBUFFER_OUTPUT "${CMAKE_BINARY_DIR}/schema")
set(FBS_FILES
${CMAKE_CURRENT_SOURCE_DIR}/../../schema/cipher.fbs
${CMAKE_CURRENT_SOURCE_DIR}/../../schema/fl_job.fbs
)
ms_build_flatbuffers(FBS_FILES ${CMAKE_CURRENT_SOURCE_DIR}../../schema generated_fbs_files ${SERVER_FLATBUFFER_OUTPUT})
if(NOT ENABLE_CPU OR WIN32)
list(REMOVE_ITEM _PS_SRC_FILES "optimizer_info_builder.cc") list(REMOVE_ITEM _PS_SRC_FILES "optimizer_info_builder.cc")
list(REMOVE_ITEM _PS_SRC_FILES "optimizer_info.cc") list(REMOVE_ITEM _PS_SRC_FILES "optimizer_info.cc")
list(REMOVE_ITEM _PS_SRC_FILES "scheduler.cc") list(REMOVE_ITEM _PS_SRC_FILES "scheduler.cc")
@ -12,11 +19,6 @@ if(NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)))
list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/tcp_client.cc") list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/tcp_client.cc")
list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/tcp_message_handler.cc") list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/tcp_message_handler.cc")
list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/tcp_server.cc") list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/tcp_server.cc")
list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/communicator_base.cc")
list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/http_communicator.cc")
list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/tcp_communicator.cc")
list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/http_msg_handler.cc")
list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/tcp_msg_handler.cc")
list(REMOVE_ITEM _PS_SRC_FILES "core/node.cc") list(REMOVE_ITEM _PS_SRC_FILES "core/node.cc")
list(REMOVE_ITEM _PS_SRC_FILES "core/node_manager.cc") list(REMOVE_ITEM _PS_SRC_FILES "core/node_manager.cc")
list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_cache_manager.cc") list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_cache_manager.cc")
@ -39,18 +41,32 @@ if(NOT ENABLE_GPU)
list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/gpu/gpu_ps_cache.cc") list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/gpu/gpu_ps_cache.cc")
endif() endif()
if(WIN32 OR NOT ENABLE_CPU) if(NOT ENABLE_CPU OR WIN32)
list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/communicator_base.cc")
list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/http_communicator.cc")
list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/tcp_communicator.cc")
list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/http_msg_handler.cc")
list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/tcp_msg_handler.cc")
list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/apply_momentum_kernel.cc") list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/apply_momentum_kernel.cc")
list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/aggregation_kernel_factory.cc") list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/aggregation_kernel_factory.cc")
list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/dense_grad_accum_kernel.cc") list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/dense_grad_accum_kernel.cc")
list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/optimizer_kernel_factory.cc") list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/optimizer_kernel_factory.cc")
list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/round/round_kernel_factory.cc")
list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/round/round_kernel.cc")
list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/round/start_fl_job_kernel.cc")
list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/params_info.cc") list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/params_info.cc")
list(REMOVE_ITEM _PS_SRC_FILES "server/consistent_hash_ring.cc") list(REMOVE_ITEM _PS_SRC_FILES "server/consistent_hash_ring.cc")
list(REMOVE_ITEM _PS_SRC_FILES "server/iteration_timer.cc") list(REMOVE_ITEM _PS_SRC_FILES "server/iteration_timer.cc")
list(REMOVE_ITEM _PS_SRC_FILES "server/local_meta_storage.cc") list(REMOVE_ITEM _PS_SRC_FILES "server/local_meta_store.cc")
list(REMOVE_ITEM _PS_SRC_FILES "server/memory_register.cc") list(REMOVE_ITEM _PS_SRC_FILES "server/memory_register.cc")
list(REMOVE_ITEM _PS_SRC_FILES "server/parameter_aggregator.cc") list(REMOVE_ITEM _PS_SRC_FILES "server/parameter_aggregator.cc")
list(REMOVE_ITEM _PS_SRC_FILES "server/executor.cc") list(REMOVE_ITEM _PS_SRC_FILES "server/executor.cc")
list(REMOVE_ITEM _PS_SRC_FILES "server/collective_ops_impl.cc")
list(REMOVE_ITEM _PS_SRC_FILES "server/distributed_count_service.cc")
list(REMOVE_ITEM _PS_SRC_FILES "server/distributed_metadata_store.cc")
list(REMOVE_ITEM _PS_SRC_FILES "server/iteration.cc")
list(REMOVE_ITEM _PS_SRC_FILES "server/model_store.cc")
list(REMOVE_ITEM _PS_SRC_FILES "server/round.cc")
endif() endif()
list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_data/ps_data_prefetch.cc") list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_data/ps_data_prefetch.cc")
@ -59,3 +75,5 @@ add_subdirectory(ps_cache)
set_property(SOURCE ${_PS_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PS) set_property(SOURCE ${_PS_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PS)
add_library(_mindspore_ps_obj OBJECT ${_PS_SRC_FILES}) add_library(_mindspore_ps_obj OBJECT ${_PS_SRC_FILES})
add_dependencies(_mindspore_ps_obj generated_fbs_files)
target_link_libraries(_mindspore_ps_obj mindspore::flatbuffers)

View File

@ -34,13 +34,26 @@
namespace mindspore { namespace mindspore {
namespace ps { namespace ps {
namespace core { namespace core {
enum class TcpUserCommand { kPush, kPull, kCount, kReachThreshold, kResetCount, kGetValue, kPutValue, kCounterEvent }; enum class TcpUserCommand {
kPush,
kPull,
kCount,
kReachThreshold,
kResetCount,
kGetMetadata,
kUpdateMetadata,
kCounterEvent
};
const std::unordered_map<TcpUserCommand, std::string> kUserCommandToMsgType = { const std::unordered_map<TcpUserCommand, std::string> kUserCommandToMsgType = {
{TcpUserCommand::kPush, "push"}, {TcpUserCommand::kPull, "pull"}, {TcpUserCommand::kPush, "push"},
{TcpUserCommand::kCount, "count"}, {TcpUserCommand::kReachThreshold, "reachThreshold"}, {TcpUserCommand::kPull, "pull"},
{TcpUserCommand::kResetCount, "resetCnt"}, {TcpUserCommand::kGetValue, "getValue"}, {TcpUserCommand::kCount, "count"},
{TcpUserCommand::kPutValue, "putValue"}, {TcpUserCommand::kCounterEvent, "counterEvent"}, {TcpUserCommand::kReachThreshold, "countReachThreshold"},
{TcpUserCommand::kResetCount, "resetCnt"},
{TcpUserCommand::kGetMetadata, "getMetadata"},
{TcpUserCommand::kUpdateMetadata, "updateMetadata"},
{TcpUserCommand::kCounterEvent, "counterEvent"},
}; };
class TcpCommunicator : public CommunicatorBase { class TcpCommunicator : public CommunicatorBase {

View File

@ -0,0 +1,155 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
syntax = "proto3";
package mindspore.ps;
message CollectiveData {
bytes data = 1;
}
message CountRequest {
string name = 1;
string id = 2;
}
message CountResponse {
bool result = 1;
string reason = 2;
}
message CountReachThresholdRequest {
string name = 1;
}
message CountReachThresholdResponse {
bool is_enough = 1;
}
message ResetCounterRequest {
string name = 1;
}
message UpdateMetadataRequest {
string name = 1;
bytes value = 2;
}
message GetMetadataRequest {
string name = 1;
}
message GetMetadataResponse {
bytes value = 1;
}
enum CounterEventType {
FIRST_CNT = 0;
LAST_CNT = 1;
}
message CounterEvent {
CounterEventType type = 1;
string name = 2;
bytes data = 3;
}
message FLId {
string fl_id = 1;
}
message UpdateModelClientList {
repeated string fl_id = 1;
}
message DeviceMeta {
string fl_name = 1;
string fl_id = 2;
uint64 data_size = 3;
}
message FLIdToDeviceMeta {
map<string, DeviceMeta> fl_id_to_meta = 1;
}
message UpdateModelThreshold {
uint64 threshold = 1;
}
message ClientShares {
map<string, SharesPb> client_secret_shares = 1;
}
message PairClientShares {
string fl_id = 1;
SharesPb client_shares = 2;
}
message ClientKeys {
map<string, KeysPb> client_keys = 1;
}
message ClientNoises {
OneClientNoises one_client_noises = 1;
}
message PairClientKeys {
string fl_id = 1;
KeysPb client_keys = 2;
}
message OneClientNoises {
repeated float noise = 1;
}
message ClientShareStr {
string fl_id = 1;
bytes share = 2; // todo: verify the correctness
int32 index = 3;
}
message SharesPb {
repeated ClientShareStr clientsharestrs = 1;
}
message KeysPb {
repeated bytes key = 1;
}
message PBMetadata {
oneof value {
DeviceMeta device_meta = 1;
FLIdToDeviceMeta device_metas = 2;
FLId fl_id = 3;
UpdateModelClientList client_list = 4;
UpdateModelThreshold update_model_threshold = 5;
PairClientShares pair_client_shares = 6;
ClientShares client_shares = 7;
PairClientKeys pair_client_keys = 8;
ClientKeys client_keys = 9;
OneClientNoises one_client_noises = 10;
ClientNoises client_noises = 11;
}
}
message PBMetadataWithName {
string name = 1;
PBMetadata metadata = 2;
}

View File

@ -1,4 +1,4 @@
if(ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)) if(ENABLE_CPU AND NOT WIN32)
file(GLOB_RECURSE _PS_CACHE_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "ps_data/*.cc") file(GLOB_RECURSE _PS_CACHE_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "ps_data/*.cc")
set_property(SOURCE ${_PS_CACHE_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PS) set_property(SOURCE ${_PS_CACHE_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PS)
add_library(ps_cache SHARED ${_PS_CACHE_SRC_FILES}) add_library(ps_cache SHARED ${_PS_CACHE_SRC_FILES})

View File

@ -18,7 +18,7 @@
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
#include "utils/ms_utils.h" #include "utils/ms_utils.h"
#include "backend/kernel_compiler/kernel.h" #include "backend/kernel_compiler/kernel.h"
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) #if (ENABLE_CPU && !_WIN32)
#include "ps/ps_cache/ps_cache_manager.h" #include "ps/ps_cache/ps_cache_manager.h"
#include "ps/ps_cache/ps_data/ps_data_prefetch.h" #include "ps/ps_cache/ps_data/ps_data_prefetch.h"
#endif #endif
@ -68,7 +68,7 @@ void PSContext::Reset() {
is_worker_ = false; is_worker_ = false;
is_pserver_ = false; is_pserver_ = false;
is_sched_ = false; is_sched_ = false;
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) #if (ENABLE_CPU && !_WIN32)
if (ps::PsDataPrefetch::GetInstance().cache_enable()) { if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
ps_cache_instance.Finalize(); ps_cache_instance.Finalize();
set_cache_enable(false); set_cache_enable(false);
@ -108,46 +108,62 @@ int PSContext::ps_rank_id() const { return rank_id_; }
void PSContext::InsertHashTableSize(const std::string &param_name, size_t cache_vocab_size, size_t embedding_size, void PSContext::InsertHashTableSize(const std::string &param_name, size_t cache_vocab_size, size_t embedding_size,
size_t vocab_size) const { size_t vocab_size) const {
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) #if (ENABLE_CPU && !_WIN32)
ps_cache_instance.InsertHashTableSize(param_name, cache_vocab_size, embedding_size, vocab_size); ps_cache_instance.InsertHashTableSize(param_name, cache_vocab_size, embedding_size, vocab_size);
#endif #endif
} }
void PSContext::ReInsertHashTableSize(const std::string &new_param_name, const std::string &cur_param_name, void PSContext::ReInsertHashTableSize(const std::string &new_param_name, const std::string &cur_param_name,
size_t cache_vocab_size, size_t embedding_size) const { size_t cache_vocab_size, size_t embedding_size) const {
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) #if (ENABLE_CPU && !_WIN32)
ps_cache_instance.ReInsertHashTableSize(new_param_name, cur_param_name, cache_vocab_size, embedding_size); ps_cache_instance.ReInsertHashTableSize(new_param_name, cur_param_name, cache_vocab_size, embedding_size);
#endif #endif
} }
void PSContext::InsertWeightInitInfo(const std::string &param_name, size_t global_seed, size_t op_seed) const { void PSContext::InsertWeightInitInfo(const std::string &param_name, size_t global_seed, size_t op_seed) const {
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) #if (ENABLE_CPU && !_WIN32)
ps_cache_instance.InsertWeightInitInfo(param_name, global_seed, op_seed); ps_cache_instance.InsertWeightInitInfo(param_name, global_seed, op_seed);
#endif #endif
} }
void PSContext::InsertAccumuInitInfo(const std::string &param_name, float init_val) const { void PSContext::InsertAccumuInitInfo(const std::string &param_name, float init_val) const {
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) #if (ENABLE_CPU && !_WIN32)
ps_cache_instance.InsertAccumuInitInfo(param_name, init_val); ps_cache_instance.InsertAccumuInitInfo(param_name, init_val);
#endif #endif
} }
void PSContext::CloneHashTable(const std::string &dest_param_name, const std::string &src_param_name) const { void PSContext::CloneHashTable(const std::string &dest_param_name, const std::string &src_param_name) const {
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) #if (ENABLE_CPU && !_WIN32)
ps_cache_instance.CloneHashTable(dest_param_name, src_param_name); ps_cache_instance.CloneHashTable(dest_param_name, src_param_name);
#endif #endif
} }
void PSContext::set_cache_enable(bool cache_enable) const { void PSContext::set_cache_enable(bool cache_enable) const {
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) #if (ENABLE_CPU && !_WIN32)
PsDataPrefetch::GetInstance().set_cache_enable(cache_enable); PsDataPrefetch::GetInstance().set_cache_enable(cache_enable);
#endif #endif
} }
void PSContext::set_rank_id(int rank_id) const { void PSContext::set_rank_id(int rank_id) const {
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) #if (ENABLE_CPU && !_WIN32)
ps_cache_instance.set_rank_id(rank_id); ps_cache_instance.set_rank_id(rank_id);
#endif #endif
} }
void PSContext::set_fl_name(const std::string &fl_name) { fl_name_ = fl_name; }
const std::string &PSContext::fl_name() const { return fl_name_; }
void PSContext::set_fl_iteration_num(uint64_t fl_iteration_num) { fl_iteration_num_ = fl_iteration_num; }
uint64_t PSContext::fl_iteration_num() const { return fl_iteration_num_; }
void PSContext::set_client_epoch_num(uint64_t client_epoch_num) { client_epoch_num_ = client_epoch_num; }
uint64_t PSContext::client_epoch_num() const { return client_epoch_num_; }
void PSContext::set_client_batch_size(uint64_t client_batch_size) { client_batch_size_ = client_batch_size; }
uint64_t PSContext::client_batch_size() const { return client_batch_size_; }
} // namespace ps } // namespace ps
} // namespace mindspore } // namespace mindspore

View File

@ -60,6 +60,19 @@ class PSContext {
void set_cache_enable(bool cache_enable) const; void set_cache_enable(bool cache_enable) const;
void set_rank_id(int rank_id) const; void set_rank_id(int rank_id) const;
// Setter and getter for federated learning.
void set_fl_name(const std::string &fl_name);
const std::string &fl_name() const;
void set_fl_iteration_num(uint64_t fl_iteration_num);
uint64_t fl_iteration_num() const;
void set_client_epoch_num(uint64_t client_epoch_num);
uint64_t client_epoch_num() const;
void set_client_batch_size(uint64_t client_batch_size);
uint64_t client_batch_size() const;
private: private:
PSContext() PSContext()
: ps_enabled_(false), : ps_enabled_(false),
@ -80,6 +93,12 @@ class PSContext {
uint32_t server_num_; uint32_t server_num_;
std::string scheduler_host_; std::string scheduler_host_;
uint16_t scheduler_port_; uint16_t scheduler_port_;
// Members for federated learning.
std::string fl_name_;
uint64_t fl_iteration_num_;
uint64_t client_epoch_num_;
uint64_t client_batch_size_;
}; };
} // namespace ps } // namespace ps
} // namespace mindspore } // namespace mindspore

View File

@ -0,0 +1,223 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ps/server/collective_ops_impl.h"
namespace mindspore {
namespace ps {
namespace server {
void CollectiveOpsImpl::Initialize(const std::shared_ptr<core::ServerNode> &server_node) {
MS_EXCEPTION_IF_NULL(server_node);
server_node_ = server_node;
local_rank_ = server_node_->rank_id();
server_num_ = PSContext::instance()->initial_server_num();
return;
}
template <typename T>
bool CollectiveOpsImpl::RingAllReduce(const void *sendbuff, void *recvbuff, size_t count) {
int ret = memcpy_s(recvbuff, count * sizeof(T), sendbuff, count * sizeof(T));
if (ret != 0) {
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
return false;
}
uint32_t rank_size = server_num_;
uint32_t local_rank_ = server_node_->rank_id();
size_t chunk_size = count / rank_size;
size_t remainder_size = count % rank_size;
std::vector<size_t> chunk_sizes(rank_size, chunk_size);
// The rest of the data should be assigned to each chunk.
for (size_t i = 0; i < remainder_size; i++) {
chunk_sizes[i]++;
}
// Store offsets to get every data chunk's address.
std::vector<size_t> chunk_offset;
for (size_t i = 0; i < rank_size; i++) {
size_t ofs =
std::accumulate(chunk_sizes.begin(), chunk_sizes.begin() + i, static_cast<size_t>(0), std::plus<size_t>());
chunk_offset.push_back(ofs);
}
T *output_buff = reinterpret_cast<T *>(recvbuff);
uint32_t send_to_rank = (local_rank_ + 1) % rank_size;
uint32_t recv_from_rank = (local_rank_ - 1 + rank_size) % rank_size;
MS_LOG(DEBUG) << "AllReduce count:" << count << ", rank_size:" << rank_size << ", local_rank_:" << local_rank_
<< ", chunk_size:" << chunk_size << ", remainder_size:" << remainder_size
<< ", chunk_sizes:" << chunk_sizes << ", send_to_rank:" << send_to_rank
<< ", recv_from_rank:" << recv_from_rank;
// Ring ReduceScatter.
MS_LOG(DEBUG) << "Start Ring ReduceScatter.";
std::unique_ptr<T[]> tmp_recv_chunk = std::make_unique<T[]>(chunk_sizes[0]);
for (size_t i = 0; i < rank_size - 1; i++) {
// Step 1: Async send data to next rank.
size_t send_chunk_index = (local_rank_ - i + rank_size) % rank_size;
T *send_chunk = output_buff + chunk_offset[send_chunk_index];
auto send_req_id = server_node_->CollectiveSendAsync(core::NodeRole::SERVER, send_to_rank, send_chunk,
chunk_sizes[send_chunk_index] * sizeof(T));
// Step 2: Async receive data to next rank and wait until it's done.
size_t recv_chunk_index = (local_rank_ - i - 1 + rank_size) % rank_size;
T *recv_chunk = output_buff + chunk_offset[recv_chunk_index];
MS_LOG(DEBUG) << "Ring ReduceScatter send_to_rank:" << send_to_rank << ", recv_from_rank:" << recv_from_rank
<< ", send count:" << chunk_sizes[send_chunk_index]
<< ", recv count:" << chunk_sizes[recv_chunk_index] << ", iteration:" << i;
std::shared_ptr<std::vector<unsigned char>> recv_str;
auto recv_req_id = server_node_->CollectiveReceiveAsync(core::NodeRole::SERVER, recv_from_rank, &recv_str);
if (!server_node_->CollectiveWait(recv_req_id, 1)) {
MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed.";
return false;
}
memcpy_s(tmp_recv_chunk.get(), chunk_sizes[recv_chunk_index] * sizeof(T), recv_str->data(), recv_str->size());
// Step 3: Reduce the data so we can overlap the time cost of send.
for (size_t j = 0; j < chunk_sizes[recv_chunk_index]; j++) {
recv_chunk[j] += tmp_recv_chunk[j];
}
// Step 4: Wait until send is done.
if (!server_node_->Wait(send_req_id, 1)) {
MS_LOG(ERROR) << "CollectiveWait " << send_req_id << " failed.";
return false;
}
}
MS_LOG(DEBUG) << "End Ring ReduceScatter.";
// Ring AllGather.
MS_LOG(DEBUG) << "Start Ring AllGather.";
for (size_t i = 0; i < rank_size - 1; i++) {
size_t send_chunk_index = (local_rank_ - i + 1 + rank_size) % rank_size;
T *send_chunk = output_buff + chunk_offset[send_chunk_index];
auto send_req_id = server_node_->CollectiveSendAsync(core::NodeRole::SERVER, send_to_rank, send_chunk,
chunk_sizes[send_chunk_index] * sizeof(T));
size_t recv_chunk_index = (local_rank_ - i + rank_size) % rank_size;
T *recv_chunk = output_buff + chunk_offset[recv_chunk_index];
MS_LOG(DEBUG) << "Ring AllGather send_to_rank:" << send_to_rank << ", recv_from_rank:" << recv_from_rank
<< ", send count:" << chunk_sizes[send_chunk_index]
<< ", recv count:" << chunk_sizes[recv_chunk_index] << ", iteration:" << i;
std::shared_ptr<std::vector<unsigned char>> recv_str;
auto recv_req_id = server_node_->CollectiveReceiveAsync(core::NodeRole::SERVER, recv_from_rank, &recv_str);
if (!server_node_->CollectiveWait(recv_req_id, 1)) {
MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed.";
return false;
}
memcpy_s(recv_chunk, chunk_sizes[recv_chunk_index] * sizeof(T), recv_str->data(), recv_str->size());
if (!server_node_->Wait(send_req_id, 1)) {
MS_LOG(ERROR) << "CollectiveWait " << send_req_id << " failed.";
return false;
}
}
MS_LOG(DEBUG) << "End Ring AllGather.";
return true;
}
template <typename T>
bool CollectiveOpsImpl::ReduceBroadcastAllReduce(const void *sendbuff, void *recvbuff, size_t count) {
uint32_t rank_size = server_num_;
uint32_t local_rank_ = server_node_->rank_id();
MS_LOG(DEBUG) << "Reduce Broadcast AllReduce rank_size:" << rank_size << ", local_rank_:" << local_rank_
<< ", count:" << count;
int ret = memcpy_s(recvbuff, count * sizeof(T), sendbuff, count * sizeof(T));
if (ret != 0) {
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
return false;
}
T *output_buff = reinterpret_cast<T *>(recvbuff);
// Reduce data to rank 0 process.
MS_LOG(DEBUG) << "Start Reduce to rank 0 process.";
if (local_rank_ == 0) {
std::unique_ptr<T[]> tmp_recv_buff = std::make_unique<T[]>(count);
for (uint32_t i = 1; i < rank_size; i++) {
std::shared_ptr<std::vector<unsigned char>> recv_str;
MS_LOG(DEBUG) << "Reduce rank 0 receive from rank " << i;
auto recv_req_id = server_node_->CollectiveReceiveAsync(core::NodeRole::SERVER, i, &recv_str);
if (!server_node_->CollectiveWait(recv_req_id, 1)) {
MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed.";
return false;
}
memcpy_s(tmp_recv_buff.get(), count * sizeof(T), recv_str->data(), recv_str->size());
for (size_t j = 0; j < count; j++) {
output_buff[j] += tmp_recv_buff[j];
}
}
} else {
MS_LOG(DEBUG) << "Reduce send data to rank 0 process.";
auto send_req_id = server_node_->CollectiveSendAsync(core::NodeRole::SERVER, 0, sendbuff, count * sizeof(T));
if (!server_node_->Wait(send_req_id, 1)) {
MS_LOG(ERROR) << "CollectiveWait " << send_req_id << " failed.";
return false;
}
}
MS_LOG(DEBUG) << "End Reduce.";
// Broadcast data to not 0 rank process.
MS_LOG(DEBUG) << "Start broadcast from rank 0 to other processes.";
if (local_rank_ == 0) {
for (uint32_t i = 1; i < rank_size; i++) {
MS_LOG(DEBUG) << "Broadcast data to process " << i;
auto send_req_id = server_node_->CollectiveSendAsync(core::NodeRole::SERVER, i, output_buff, count * sizeof(T));
if (!server_node_->Wait(send_req_id, 1)) {
MS_LOG(ERROR) << "CollectiveWait " << send_req_id << " failed.";
return false;
}
}
} else {
MS_LOG(DEBUG) << "Broadcast receive from rank 0.";
std::shared_ptr<std::vector<unsigned char>> recv_str;
auto recv_req_id = server_node_->CollectiveReceiveAsync(core::NodeRole::SERVER, 0, &recv_str);
if (!server_node_->CollectiveWait(recv_req_id, 1)) {
MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed.";
return false;
}
memcpy_s(output_buff, count * sizeof(T), recv_str->data(), recv_str->size());
}
MS_LOG(DEBUG) << "End broadcast.";
return true;
}
template <typename T>
bool CollectiveOpsImpl::AllReduce(const void *sendbuff, void *recvbuff, size_t count) {
// The collective communication API does not support calling Send and Recv concurrently with multiple threads;
std::unique_lock<std::mutex> lock(mtx_);
if (sendbuff == nullptr || recvbuff == nullptr) {
MS_LOG(ERROR) << "AllReduce sendbuff or recvbuff is nullptr.";
return false;
}
uint32_t rank_size = server_num_;
if (count >= rank_size) {
return RingAllReduce<T>(sendbuff, recvbuff, count);
} else {
return ReduceBroadcastAllReduce<T>(sendbuff, recvbuff, count);
}
}
template bool CollectiveOpsImpl::RingAllReduce<float>(const void *sendbuff, void *recvbuff, size_t count);
template bool CollectiveOpsImpl::RingAllReduce<size_t>(const void *sendbuff, void *recvbuff, size_t count);
template bool CollectiveOpsImpl::RingAllReduce<int>(const void *sendbuff, void *recvbuff, size_t count);
template bool CollectiveOpsImpl::ReduceBroadcastAllReduce<float>(const void *sendbuff, void *recvbuff, size_t count);
template bool CollectiveOpsImpl::ReduceBroadcastAllReduce<size_t>(const void *sendbuff, void *recvbuff, size_t count);
template bool CollectiveOpsImpl::ReduceBroadcastAllReduce<int>(const void *sendbuff, void *recvbuff, size_t count);
template bool CollectiveOpsImpl::AllReduce<float>(const void *sendbuff, void *recvbuff, size_t count);
template bool CollectiveOpsImpl::AllReduce<size_t>(const void *sendbuff, void *recvbuff, size_t count);
template bool CollectiveOpsImpl::AllReduce<int>(const void *sendbuff, void *recvbuff, size_t count);
} // namespace server
} // namespace ps
} // namespace mindspore

View File

@ -0,0 +1,71 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_SERVER_COLLECTIVE_OPS_IMPL_H_
#define MINDSPORE_CCSRC_PS_SERVER_COLLECTIVE_OPS_IMPL_H_
#include <memory>
#include <string>
#include <vector>
#include <functional>
#include "proto/ps.pb.h"
#include "ps/ps_context.h"
#include "ps/core/server_node.h"
#include "ps/server/common.h"
namespace mindspore {
namespace ps {
namespace server {
// CollectiveOpsImpl is the collective communication API of the server.
// For now, it implements two AllReduce algorithms: RingAllReduce and BroadcastAllReduce. Elastic AllReduce is also
// supported for the elastic scaling feature of the server.
class CollectiveOpsImpl {
public:
static CollectiveOpsImpl &GetInstance() {
static CollectiveOpsImpl instance;
return instance;
}
void Initialize(const std::shared_ptr<core::ServerNode> &server_node);
template <typename T>
bool AllReduce(const void *sendbuff, void *recvbuff, size_t count);
private:
CollectiveOpsImpl() = default;
~CollectiveOpsImpl() = default;
CollectiveOpsImpl(const CollectiveOpsImpl &) = delete;
CollectiveOpsImpl &operator=(const CollectiveOpsImpl &) = delete;
// Implementation of RingAllReduce.
template <typename T>
bool RingAllReduce(const void *sendbuff, void *recvbuff, size_t count);
// Implementation of BroadcastAllReduce.
template <typename T>
bool ReduceBroadcastAllReduce(const void *sendbuff, void *recvbuff, size_t count);
std::shared_ptr<core::ServerNode> server_node_;
uint32_t local_rank_;
uint32_t server_num_;
// The mutex to ensure that collective communication is threadsafe.
std::mutex mtx_;
};
} // namespace server
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_COLLECTIVE_OPS_IMPL_H_

View File

@ -24,13 +24,17 @@
#include <memory> #include <memory>
#include <functional> #include <functional>
#include "proto/ps.pb.h" #include "proto/ps.pb.h"
#include "proto/fl.pb.h"
#include "ir/anf.h" #include "ir/anf.h"
#include "utils/utils.h" #include "utils/utils.h"
#include "ir/dtype/type_id.h" #include "ir/dtype/type_id.h"
#include "backend/kernel_compiler/cpu/cpu_kernel.h" #include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "schema/fl_job_generated.h"
#include "schema/cipher_generated.h"
#include "ps/ps_context.h" #include "ps/ps_context.h"
#include "ps/core/communicator/http_message_handler.h" #include "ps/core/communicator/http_message_handler.h"
#include "ps/core/communicator/tcp_server.h" #include "ps/core/communicator/tcp_server.h"
#include "ps/core/communicator/message_handler.h"
namespace mindspore { namespace mindspore {
namespace ps { namespace ps {
@ -40,13 +44,15 @@ enum ServerMode { PARAMETER_SERVER = 0, FL_SERVER };
enum CommType { HTTP = 0, TCP }; enum CommType { HTTP = 0, TCP };
enum AggregationType { FedAvg = 0, FedAdam, FedAdagarg, FedMeta, qffl, DenseGradAccum, SparseGradAccum }; enum AggregationType { FedAvg = 0, FedAdam, FedAdagarg, FedMeta, qffl, DenseGradAccum, SparseGradAccum };
using kernel::Address; using mindspore::kernel::Address;
using kernel::AddressPtr; using mindspore::kernel::AddressPtr;
using kernel::CPUKernel; using mindspore::kernel::CPUKernel;
using FBBuilder = flatbuffers::FlatBufferBuilder;
using TimeOutCb = std::function<void(void)>; using TimeOutCb = std::function<void(void)>;
using StopTimerCb = std::function<void(void)>; using StopTimerCb = std::function<void(void)>;
using FinishIterCb = std::function<void(void)>; using FinishIterCb = std::function<void(void)>;
using FinalizeCb = std::function<void(void)>; using FinalizeCb = std::function<void(void)>;
using MessageCallback = std::function<void(const std::shared_ptr<core::MessageHandler> &)>;
// Information about whether server kernel will reuse kernel node memory from the front end. // Information about whether server kernel will reuse kernel node memory from the front end.
// Key refers to the server kernel's parameter name, like "weights", "grad", "learning_rate". // Key refers to the server kernel's parameter name, like "weights", "grad", "learning_rate".

View File

@ -0,0 +1,298 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ps/server/distributed_count_service.h"
#include <string>
#include <memory>
#include <vector>
namespace mindspore {
namespace ps {
namespace server {
void DistributedCountService::Initialize(const std::shared_ptr<core::ServerNode> &server_node,
uint32_t counting_server_rank) {
server_node_ = server_node;
MS_EXCEPTION_IF_NULL(server_node_);
communicator_ =
std::dynamic_pointer_cast<core::TcpCommunicator>(server_node_->GetOrCreateTcpComm("", 0, 0, 0, nullptr));
MS_EXCEPTION_IF_NULL(communicator_);
local_rank_ = server_node_->rank_id();
server_num_ = PSContext::instance()->initial_server_num();
counting_server_rank_ = counting_server_rank;
RegisterCallback();
return;
}
void DistributedCountService::RegisterCounter(const std::string &name, size_t global_threshold_count,
const CounterHandlers &counter_handlers) {
if (!counter_handlers.first_count_handler || !counter_handlers.last_count_handler) {
MS_LOG(EXCEPTION) << "First count handler or last count handler is not set.";
return;
}
if (global_threshold_count_.count(name) != 0) {
MS_LOG(ERROR) << "Counter for " << name << " is already set.";
return;
}
MS_LOG(INFO) << "Rank " << local_rank_ << " register counter for " << name << " count:" << global_threshold_count;
// If the server is the leader server, it needs to set the counter handlers and do the real counting.
if (local_rank_ == counting_server_rank_) {
global_current_count_[name] = {};
global_threshold_count_[name] = global_threshold_count;
mutex_[name];
}
counter_handlers_[name] = counter_handlers;
return;
}
bool DistributedCountService::Count(const std::string &name, const std::string &id) {
MS_LOG(INFO) << "Rank " << local_rank_ << " reports count for " << name << " of " << id;
if (local_rank_ == counting_server_rank_) {
if (global_threshold_count_.count(name) == 0) {
MS_LOG(ERROR) << "Counter for " << name << " is not registered.";
return false;
}
std::unique_lock<std::mutex> lock(mutex_[name]);
if (global_current_count_[name].size() >= global_threshold_count_[name]) {
MS_LOG(ERROR) << "Count for " << name << " is already enough. Threshold count is "
<< global_threshold_count_[name];
return false;
}
MS_LOG(INFO) << "Leader server increase count for " << name << " of " << id;
global_current_count_[name].insert(id);
TriggerCounterEvent(name);
} else {
// If this server is a follower server, it needs to send CountRequest to the leader server.
CountRequest report_count_req;
report_count_req.set_name(name);
report_count_req.set_id(id);
std::shared_ptr<std::vector<unsigned char>> report_cnt_rsp_msg = nullptr;
if (!communicator_->SendPbRequest(report_count_req, counting_server_rank_, core::TcpUserCommand::kCount,
&report_cnt_rsp_msg)) {
MS_LOG(ERROR) << "Sending reporting count message to leader server failed for " << name;
return false;
}
CountResponse count_rsp;
count_rsp.ParseFromArray(report_cnt_rsp_msg->data(), report_cnt_rsp_msg->size());
if (!count_rsp.result()) {
MS_LOG(ERROR) << "Reporting count failed:" << count_rsp.reason();
return false;
}
}
return true;
}
bool DistributedCountService::CountReachThreshold(const std::string &name) {
MS_LOG(INFO) << "Rank " << local_rank_ << " query whether count reaches threshold for " << name;
if (local_rank_ == counting_server_rank_) {
if (global_threshold_count_.count(name) == 0) {
MS_LOG(ERROR) << "Counter for " << name << " is not set.";
return false;
}
std::unique_lock<std::mutex> lock(mutex_[name]);
return global_current_count_[name].size() == global_threshold_count_[name];
} else {
CountReachThresholdRequest count_reach_threashold_req;
count_reach_threashold_req.set_name(name);
std::shared_ptr<std::vector<unsigned char>> query_cnt_enough_rsp_msg = nullptr;
if (!communicator_->SendPbRequest(count_reach_threashold_req, counting_server_rank_,
core::TcpUserCommand::kReachThreshold, &query_cnt_enough_rsp_msg)) {
MS_LOG(ERROR) << "Sending querying whether count reaches threshold message to leader server failed for " << name;
return false;
}
CountReachThresholdResponse count_reach_threashold_rsp;
count_reach_threashold_rsp.ParseFromArray(query_cnt_enough_rsp_msg->data(), query_cnt_enough_rsp_msg->size());
return count_reach_threashold_rsp.is_enough();
}
}
void DistributedCountService::ResetCounter(const std::string &name) {
if (local_rank_ == counting_server_rank_) {
MS_LOG(INFO) << "Leader server reset count for " << name;
global_current_count_[name].clear();
}
return;
}
void DistributedCountService::RegisterCallback() {
if (local_rank_ == counting_server_rank_) {
communicator_->RegisterMsgCallBack(
"count", std::bind(&DistributedCountService::HandleCountRequest, this, std::placeholders::_1));
communicator_->RegisterMsgCallBack(
"countReachThreshold",
std::bind(&DistributedCountService::HandleCountReachThresholdRequest, this, std::placeholders::_1));
}
// The callback of first/last event must be set in both leader server and follower servers.
communicator_->RegisterMsgCallBack(
"counterEvent", std::bind(&DistributedCountService::HandleCounterEvent, this, std::placeholders::_1));
}
void DistributedCountService::HandleCountRequest(const std::shared_ptr<core::MessageHandler> &message) {
if (message == nullptr) {
MS_LOG(ERROR) << "Message is nullptr.";
return;
}
CountRequest report_count_req;
report_count_req.ParseFromArray(message->data(), message->len());
const std::string &name = report_count_req.name();
const std::string &id = report_count_req.id();
CountResponse count_rsp;
std::unique_lock<std::mutex> lock(mutex_[name]);
// If leader server has no counter for the name registered, return an error.
if (global_threshold_count_.count(name) == 0) {
std::string reason = "Counter for " + name + " is not registered.";
count_rsp.set_result(false);
count_rsp.set_reason(reason);
MS_LOG(ERROR) << reason;
communicator_->SendResponse(count_rsp.SerializeAsString().data(), count_rsp.SerializeAsString().size(), message);
return;
}
// If leader server already has enough count for the name, return an error.
if (global_current_count_[name].size() >= global_threshold_count_[name]) {
std::string reason =
"Count for " + name + " is already enough. Threshold count is " + std::to_string(global_threshold_count_[name]);
count_rsp.set_result(false);
count_rsp.set_reason(reason);
MS_LOG(ERROR) << reason;
communicator_->SendResponse(count_rsp.SerializeAsString().data(), count_rsp.SerializeAsString().size(), message);
return;
}
// Insert the id for the counter, which means the count for the name is increased.
MS_LOG(INFO) << "Leader server increase count for " << name << " of " << id;
global_current_count_[name].insert(id);
TriggerCounterEvent(name);
count_rsp.set_result(true);
count_rsp.set_reason("success");
communicator_->SendResponse(count_rsp.SerializeAsString().data(), count_rsp.SerializeAsString().size(), message);
return;
}
void DistributedCountService::HandleCountReachThresholdRequest(const std::shared_ptr<core::MessageHandler> &message) {
if (message == nullptr) {
MS_LOG(ERROR) << "Message is nullptr.";
return;
}
CountReachThresholdRequest count_reach_threashold_req;
count_reach_threashold_req.ParseFromArray(message->data(), message->len());
const std::string &name = count_reach_threashold_req.name();
std::unique_lock<std::mutex> lock(mutex_[name]);
if (global_threshold_count_.count(name) == 0) {
MS_LOG(ERROR) << "Counter for " << name << " is not registered.";
return;
}
CountReachThresholdResponse count_reach_threashold_rsp;
count_reach_threashold_rsp.set_is_enough(global_current_count_[name].size() == global_threshold_count_[name]);
communicator_->SendResponse(count_reach_threashold_rsp.SerializeAsString().data(),
count_reach_threashold_rsp.SerializeAsString().size(), message);
return;
}
void DistributedCountService::HandleCounterEvent(const std::shared_ptr<core::MessageHandler> &message) {
if (message == nullptr) {
MS_LOG(ERROR) << "Message is nullptr.";
return;
}
// Respond as soon as possible so the leader server won't wait for each follower servers to finish calling the
// callbacks.
std::string couter_event_rsp_msg = "success";
communicator_->SendResponse(couter_event_rsp_msg.data(), couter_event_rsp_msg.size(), message);
CounterEvent counter_event;
counter_event.ParseFromArray(message->data(), message->len());
const auto &type = counter_event.type();
const auto &name = counter_event.name();
MS_LOG(INFO) << "Rank " << local_rank_ << " do counter event " << type << " for " << name;
if (type == CounterEventType::FIRST_CNT) {
counter_handlers_[name].first_count_handler(message);
} else if (type == CounterEventType::LAST_CNT) {
counter_handlers_[name].last_count_handler(message);
} else {
MS_LOG(ERROR) << "DistributedCountService event type " << type << " is invalid.";
return;
}
return;
}
void DistributedCountService::TriggerCounterEvent(const std::string &name) {
MS_LOG(INFO) << "Current count for " << name << " is " << global_current_count_[name].size()
<< ", threshold count is " << global_threshold_count_[name];
// The threshold count may be 1 so the first and last count event should be both activated.
if (global_current_count_[name].size() == 1) {
TriggerFirstCountEvent(name);
}
if (global_current_count_[name].size() == global_threshold_count_[name]) {
TriggerLastCountEvent(name);
}
return;
}
void DistributedCountService::TriggerFirstCountEvent(const std::string &name) {
MS_LOG(INFO) << "Activating first count event for " << name;
CounterEvent first_count_event;
first_count_event.set_type(CounterEventType::FIRST_CNT);
first_count_event.set_name(name);
// Broadcast to all follower servers.
for (uint32_t i = 1; i < server_num_; i++) {
if (!communicator_->SendPbRequest(first_count_event, i, core::TcpUserCommand::kCounterEvent)) {
MS_LOG(ERROR) << "Activating first count event to server " << i << " failed.";
return;
}
}
// Leader server directly calls the callback.
counter_handlers_[name].first_count_handler(nullptr);
return;
}
void DistributedCountService::TriggerLastCountEvent(const std::string &name) {
MS_LOG(INFO) << "Activating last count event for " << name;
CounterEvent last_count_event;
last_count_event.set_type(CounterEventType::LAST_CNT);
last_count_event.set_name(name);
// Broadcast to all follower servers.
for (uint32_t i = 1; i < server_num_; i++) {
if (!communicator_->SendPbRequest(last_count_event, i, core::TcpUserCommand::kCounterEvent)) {
MS_LOG(ERROR) << "Activating last count event to server " << i << " failed.";
return;
}
}
// Leader server directly calls the callback.
counter_handlers_[name].last_count_handler(nullptr);
return;
}
} // namespace server
} // namespace ps
} // namespace mindspore

View File

@ -0,0 +1,126 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_SERVER_DISTRIBUTED_COUNT_SERVICE_H_
#define MINDSPORE_CCSRC_PS_SERVER_DISTRIBUTED_COUNT_SERVICE_H_
#include <set>
#include <string>
#include <memory>
#include <unordered_map>
#include "proto/ps.pb.h"
#include "ps/server/common.h"
#include "ps/core/server_node.h"
#include "ps/core/communicator/tcp_communicator.h"
namespace mindspore {
namespace ps {
namespace server {
// The callbacks for the first count and last count event.
typedef struct {
MessageCallback first_count_handler;
MessageCallback last_count_handler;
} CounterHandlers;
// DistributedCountService is used for counting in the server cluster dimension. It's used for counting of rounds,
// aggregation counting, etc.
// The counting could be called by any server, but only one server has the information
// of the cluster count and we mark this server as the counting server. Other servers must communicate with this
// counting server to increase/query count number.
// On the first count or last count event, DistributedCountService on the counting server triggers the event on other
// servers by sending counter event commands. This is for the purpose of keeping server cluster's consistency.
class DistributedCountService {
public:
static DistributedCountService &GetInstance() {
static DistributedCountService instance;
return instance;
}
// Initialize counter service with the server node because communication is needed.
void Initialize(const std::shared_ptr<core::ServerNode> &server_node, uint32_t counting_server_rank);
// Register counter to the counting server for the name with its threshold count in server cluster dimension and
// first/last count event callbacks.
void RegisterCounter(const std::string &name, size_t global_threshold_count, const CounterHandlers &counter_handlers);
// Report a count to the counting server. Parameter 'id' is in case of repeated counting.
bool Count(const std::string &name, const std::string &id);
// Query whether the count reaches the threshold count for the name. If the count is the same as the threshold count,
// this method returns true.
bool CountReachThreshold(const std::string &name);
// Reset the count of the name to 0.
void ResetCounter(const std::string &name);
// Returns the server rank because in some cases the callers use this rank as the 'id' for method
// Count.
uint32_t local_rank() { return local_rank_; }
private:
DistributedCountService() = default;
~DistributedCountService() = default;
DistributedCountService(const DistributedCountService &) = delete;
DistributedCountService &operator=(const DistributedCountService &) = delete;
// Register callbacks of the counting server to handle messages sent by the other servers.
void RegisterCallback();
// Callback for the reporting count message from other servers. Only counting server will call this method.
void HandleCountRequest(const std::shared_ptr<core::MessageHandler> &message);
// Callback for the querying whether threshold count is reached message from other servers. Only counting
// server will call this method.
void HandleCountReachThresholdRequest(const std::shared_ptr<core::MessageHandler> &message);
// Callback for the first/last event message from the counting server. Only other servers will call this
// method.
void HandleCounterEvent(const std::shared_ptr<core::MessageHandler> &message);
// Call the callbacks when the first/last count event is triggered.
void TriggerCounterEvent(const std::string &name);
void TriggerFirstCountEvent(const std::string &name);
void TriggerLastCountEvent(const std::string &name);
// Members for the communication between counting server and other servers.
std::shared_ptr<core::ServerNode> server_node_;
std::shared_ptr<core::TcpCommunicator> communicator_;
uint32_t local_rank_;
uint32_t server_num_;
// Only one server will be set to do the real counting.
uint32_t counting_server_rank_;
// Key: name, e.g, startFLJob, updateModel, push.
// Value: a set of id without repeatation because each work may report multiple times.
std::unordered_map<std::string, std::set<std::string>> global_current_count_;
// Key: name, e.g, StartFLJobCount.
// Value: global threshold count in the server cluster dimension for this name.
std::unordered_map<std::string, size_t> global_threshold_count_;
// First/last count event callbacks of the name.
std::unordered_map<std::string, CounterHandlers> counter_handlers_;
// Because the count is increased/queried conccurently, we must ensure the operations are threadsafe.
std::unordered_map<std::string, std::mutex> mutex_;
};
} // namespace server
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_DISTRIBUTED_COUNT_SERVICE_H_

View File

@ -0,0 +1,201 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ps/server/distributed_metadata_store.h"
#include <memory>
#include <string>
#include <vector>
namespace mindspore {
namespace ps {
namespace server {
void DistributedMetadataStore::Initialize(const std::shared_ptr<core::ServerNode> &server_node) {
server_node_ = server_node;
MS_EXCEPTION_IF_NULL(server_node);
communicator_ =
std::dynamic_pointer_cast<core::TcpCommunicator>(server_node_->GetOrCreateTcpComm("", 0, 0, 0, nullptr));
MS_EXCEPTION_IF_NULL(communicator_);
local_rank_ = server_node_->rank_id();
server_num_ = PSContext::instance()->initial_server_num();
InitHashRing();
RegisterCallback();
return;
}
void DistributedMetadataStore::RegisterMetadata(const std::string &name, const PBMetadata &meta) {
if (router_ == nullptr) {
MS_LOG(ERROR) << "The consistent hash ring is not initialized yet.";
return;
}
uint32_t stored_rank = router_->Find(name);
if (local_rank_ == stored_rank) {
if (metadata_.count(name) != 0) {
MS_LOG(ERROR) << "The metadata for " << name << " is already registered.";
return;
}
MS_LOG(INFO) << "Rank " << local_rank_ << " register storage for metadata " << name;
metadata_[name] = meta;
mutex_[name];
}
return;
}
void DistributedMetadataStore::ResetMetadata(const std::string &name) {
if (router_ == nullptr) {
MS_LOG(ERROR) << "The consistent hash ring is not initialized yet.";
return;
}
uint32_t stored_rank = router_->Find(name);
if (local_rank_ == stored_rank) {
if (metadata_.count(name) == 0) {
MS_LOG(ERROR) << "The metadata for " << name << " is not registered.";
return;
}
MS_LOG(INFO) << "Rank " << local_rank_ << " reset metadata for " << name;
std::unique_lock<std::mutex> lock(mutex_[name]);
PBMetadata empty_meta;
metadata_[name] = empty_meta;
}
return;
}
void DistributedMetadataStore::UpdateMetadata(const std::string &name, const PBMetadata &meta) {
if (router_ == nullptr) {
MS_LOG(ERROR) << "The consistent hash ring is not initialized yet.";
return;
}
uint32_t stored_rank = router_->Find(name);
MS_LOG(INFO) << "Rank " << local_rank_ << " update value for " << name << " which is stored in rank " << stored_rank;
if (local_rank_ == stored_rank) {
if (!DoUpdateMetadata(name, meta)) {
MS_LOG(ERROR) << "Updating meta data failed.";
return;
}
} else {
PBMetadataWithName metadata_with_name;
metadata_with_name.set_name(name);
*metadata_with_name.mutable_metadata() = meta;
if (!communicator_->SendPbRequest(metadata_with_name, stored_rank, core::TcpUserCommand::kUpdateMetadata)) {
MS_LOG(ERROR) << "Sending updating metadata message to server " << stored_rank << " failed.";
return;
}
}
return;
}
PBMetadata DistributedMetadataStore::GetMetadata(const std::string &name) {
if (router_ == nullptr) {
MS_LOG(ERROR) << "The consistent hash ring is not initialized yet.";
return {};
}
uint32_t stored_rank = router_->Find(name);
MS_LOG(INFO) << "Rank " << local_rank_ << " get metadata for " << name << " which is stored in rank " << stored_rank;
if (local_rank_ == stored_rank) {
std::unique_lock<std::mutex> lock(mutex_[name]);
return metadata_[name];
} else {
GetMetadataRequest get_metadata_req;
get_metadata_req.set_name(name);
PBMetadata get_metadata_rsp;
std::shared_ptr<std::vector<unsigned char>> get_meta_rsp_msg = nullptr;
if (!communicator_->SendPbRequest(get_metadata_req, stored_rank, core::TcpUserCommand::kGetMetadata,
&get_meta_rsp_msg)) {
MS_LOG(ERROR) << "Sending getting metadata message to server " << stored_rank << " failed.";
return get_metadata_rsp;
}
get_metadata_rsp.ParseFromArray(get_meta_rsp_msg->data(), get_meta_rsp_msg->size());
return get_metadata_rsp;
}
}
void DistributedMetadataStore::InitHashRing() {
router_ = std::make_shared<ConsistentHashRing>(32);
MS_EXCEPTION_IF_NULL(router_);
for (uint32_t i = 0; i < server_num_; i++) {
bool ret = router_->Insert(i);
if (!ret) {
MS_LOG(EXCEPTION) << "Add node " << i << " to router of meta storage failed.";
return;
}
}
return;
}
void DistributedMetadataStore::RegisterCallback() {
communicator_->RegisterMsgCallBack(
"updateMetadata", std::bind(&DistributedMetadataStore::HandleUpdateMetadataRequest, this, std::placeholders::_1));
communicator_->RegisterMsgCallBack(
"getMetadata", std::bind(&DistributedMetadataStore::HandleGetMetadataRequest, this, std::placeholders::_1));
return;
}
void DistributedMetadataStore::HandleUpdateMetadataRequest(const std::shared_ptr<core::MessageHandler> &message) {
if (message == nullptr) {
MS_LOG(ERROR) << "Message is nullptr.";
return;
}
PBMetadataWithName meta_with_name;
meta_with_name.ParseFromArray(message->data(), message->len());
const std::string &name = meta_with_name.name();
MS_LOG(INFO) << "Update metadata for " << name;
std::string update_meta_rsp_msg;
if (!DoUpdateMetadata(name, meta_with_name.metadata())) {
update_meta_rsp_msg = "Updating meta data failed.";
} else {
update_meta_rsp_msg = "Success";
}
communicator_->SendResponse(update_meta_rsp_msg.data(), update_meta_rsp_msg.size(), message);
return;
}
void DistributedMetadataStore::HandleGetMetadataRequest(const std::shared_ptr<core::MessageHandler> &message) {
if (message == nullptr) {
MS_LOG(ERROR) << "Message is nullptr.";
return;
}
GetMetadataRequest get_metadata_req;
get_metadata_req.ParseFromArray(message->data(), message->len());
const std::string &name = get_metadata_req.name();
MS_LOG(INFO) << "Getting metadata for " << name;
std::unique_lock<std::mutex> lock(mutex_[name]);
PBMetadata stored_meta = metadata_[name];
std::string getting_meta_rsp_msg = stored_meta.SerializeAsString();
communicator_->SendResponse(getting_meta_rsp_msg.data(), getting_meta_rsp_msg.size(), message);
return;
}
bool DistributedMetadataStore::DoUpdateMetadata(const std::string &name, const PBMetadata &meta) {
std::unique_lock<std::mutex> lock(mutex_[name]);
metadata_[name] = meta;
return true;
}
} // namespace server
} // namespace ps
} // namespace mindspore

View File

@ -0,0 +1,101 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_SERVER_DISTRIBUTED_META_STORE_H_
#define MINDSPORE_CCSRC_PS_SERVER_DISTRIBUTED_META_STORE_H_
#include <string>
#include <memory>
#include <unordered_map>
#include "proto/ps.pb.h"
#include "ps/server/common.h"
#include "ps/core/server_node.h"
#include "ps/core/communicator/tcp_communicator.h"
#include "ps/server/consistent_hash_ring.h"
namespace mindspore {
namespace ps {
namespace server {
// This class is used for distributed metadata storage using consistent hash. All metadata is distributedly
// stored in all servers. Caller doesn't need to know which server stores the metadata. It only needs to know what kind
// of operations should be done to the metadata.
// The metadata stored in the server is in protobuffer format because it's easy for serializing and communicating. The
// type of the protobuffer struct is decided by the caller using protobuffer's API.
class DistributedMetadataStore {
public:
static DistributedMetadataStore &GetInstance() {
static DistributedMetadataStore instance;
return instance;
}
// Initialize metadata storage with the server node because communication is needed.
void Initialize(const std::shared_ptr<core::ServerNode> &server_node);
// Register metadata for the name with the initial value. This method should be only called once for each name.
void RegisterMetadata(const std::string &name, const PBMetadata &meta);
// Reset the metadata value for the name.
void ResetMetadata(const std::string &name);
// Update the metadata for the name.
void UpdateMetadata(const std::string &name, const PBMetadata &meta);
// Get the metadata for the name.
PBMetadata GetMetadata(const std::string &name);
private:
DistributedMetadataStore() = default;
~DistributedMetadataStore() = default;
DistributedMetadataStore(const DistributedMetadataStore &) = delete;
DistributedMetadataStore &operator=(const DistributedMetadataStore &) = delete;
// Initialize the consistent hash ring for distributed storage.
void InitHashRing();
// Register callbacks for the server to handle update/get metadata messages from other servers.
void RegisterCallback();
// Callback for updating metadata request sent to the server.
void HandleUpdateMetadataRequest(const std::shared_ptr<core::MessageHandler> &message);
// Callback for getting metadata request sent to the server.
void HandleGetMetadataRequest(const std::shared_ptr<core::MessageHandler> &message);
// Do updating metadata in the server where the metadata for the name is stored.
bool DoUpdateMetadata(const std::string &name, const PBMetadata &meta);
// Members for the communication between servers.
std::shared_ptr<core::ServerNode> server_node_;
std::shared_ptr<core::TcpCommunicator> communicator_;
uint32_t local_rank_;
uint32_t server_num_;
// Consistent hash ring. This is used for DistributedMetadataStore to find which server node the meta data is stored.
std::shared_ptr<ConsistentHashRing> router_;
// We store metadata which is serialized by ProtoBuffer so that data storage and data transmission API is easy to use.
// Key: data name.
// Value: ProtoBuffer Struct.
std::unordered_map<std::string, PBMetadata> metadata_;
// Because the metadata is read/written conccurently, we must ensure the operations are threadsafe.
std::unordered_map<std::string, std::mutex> mutex_;
};
} // namespace server
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_DISTRIBUTED_META_STORE_H_

View File

@ -169,7 +169,7 @@ bool Executor::HandleOverwriteWeightsByKey(const std::map<std::string, Address>
} }
AddressPtr Executor::HandlePull(const std::string &param_name) { AddressPtr Executor::HandlePull(const std::string &param_name) {
MS_LOG(INFO) << "Handle blocking pull msg for parameter " << param_name; MS_LOG(INFO) << "Handle blocking pull message for parameter " << param_name;
if (param_aggrs_.count(param_name) == 0) { if (param_aggrs_.count(param_name) == 0) {
MS_LOG(WARNING) << "Parameter " << param_name << " is not registered in server."; MS_LOG(WARNING) << "Parameter " << param_name << " is not registered in server.";
return nullptr; return nullptr;
@ -193,11 +193,6 @@ AddressPtr Executor::HandlePull(const std::string &param_name) {
return addr; return addr;
} }
std::map<std::string, AddressPtr> Executor::HandleAsyncGetModel() {
std::unique_lock<std::mutex> lock(model_mutex_);
return GetModel();
}
std::map<std::string, AddressPtr> Executor::HandleGetWeightsByKey(const std::vector<std::string> &param_names) { std::map<std::string, AddressPtr> Executor::HandleGetWeightsByKey(const std::vector<std::string> &param_names) {
std::map<std::string, AddressPtr> weights; std::map<std::string, AddressPtr> weights;
for (const auto &param_name : param_names) { for (const auto &param_name : param_names) {

View File

@ -63,10 +63,6 @@ class Executor {
// asynchronously. // asynchronously.
bool HandleModelUpdateAsync(const std::map<std::string, UploadData> &feature_map); bool HandleModelUpdateAsync(const std::map<std::string, UploadData> &feature_map);
// Called in asynchronous federated learning training mode. Returns whole model in key-value where key refers to the
// parameter name.
std::map<std::string, AddressPtr> HandleAsyncGetModel();
// Forcibly overwrite specific weights in overwriteWeights message. // Forcibly overwrite specific weights in overwriteWeights message.
bool HandleOverwriteWeightsByKey(const std::map<std::string, Address> &feature_map); bool HandleOverwriteWeightsByKey(const std::map<std::string, Address> &feature_map);

View File

@ -0,0 +1,76 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ps/server/iteration.h"
#include <memory>
#include <vector>
#include <numeric>
#include "ps/server/model_store.h"
namespace mindspore {
namespace ps {
namespace server {
Iteration::Iteration() : iteration_num_(1) { LocalMetaStore::GetInstance().set_curr_iter_num(iteration_num_); }
void Iteration::AddRound(const std::shared_ptr<Round> &round) {
MS_EXCEPTION_IF_NULL(round);
rounds_.push_back(round);
}
void Iteration::InitRounds(const std::vector<std::shared_ptr<core::CommunicatorBase>> &communicators,
const TimeOutCb &timeout_cb, const FinishIterCb &finish_iteration_cb) {
if (communicators.empty()) {
MS_LOG(EXCEPTION) << "Communicators for rounds is empty.";
return;
}
std::for_each(communicators.begin(), communicators.end(),
[&](const std::shared_ptr<core::CommunicatorBase> &communicator) {
for (auto &round : rounds_) {
if (round == nullptr) {
continue;
}
round->Initialize(communicator, timeout_cb, finish_iteration_cb);
}
});
// The time window for one iteration, which will be used in some round kernels.
size_t iteration_time_window =
std::accumulate(rounds_.begin(), rounds_.end(), 0,
[](size_t total, const std::shared_ptr<Round> &round) { return total + round->time_window(); });
LocalMetaStore::GetInstance().put_value(kCtxTotalTimeoutDuration, iteration_time_window);
return;
}
void Iteration::ProceedToNextIter() {
iteration_num_ = LocalMetaStore::GetInstance().curr_iter_num();
// Store the model for each iteration.
const auto &model = Executor::GetInstance().GetModel();
ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model);
for (auto &round : rounds_) {
round->Reset();
}
iteration_num_++;
LocalMetaStore::GetInstance().set_curr_iter_num(iteration_num_);
MS_LOG(INFO) << "Proceed to next iteration:" << iteration_num_ << "\n";
}
const std::vector<std::shared_ptr<Round>> &Iteration::rounds() { return rounds_; }
} // namespace server
} // namespace ps
} // namespace mindspore

View File

@ -0,0 +1,58 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_SERVER_ITERATION_H_
#define MINDSPORE_CCSRC_PS_SERVER_ITERATION_H_
#include <memory>
#include <vector>
#include "ps/core/communicator/communicator_base.h"
#include "ps/server/common.h"
#include "ps/server/round.h"
#include "ps/server/local_meta_store.h"
namespace mindspore {
namespace ps {
namespace server {
// In server's logic, Iteration is the minimum execution unit. For each execution, it consists of multiple kinds of
// Rounds, only after all the rounds are finished, this iteration is considered as completed.
class Iteration {
public:
Iteration();
~Iteration() = default;
// Add a round for the iteration. This method will be called multiple times for each round.
void AddRound(const std::shared_ptr<Round> &round);
// Initialize all the rounds in the iteration.
void InitRounds(const std::vector<std::shared_ptr<core::CommunicatorBase>> &communicators,
const TimeOutCb &timeout_cb, const FinishIterCb &finish_iteration_cb);
// The server proceeds to the next iteration only after the last iteration finishes.
void ProceedToNextIter();
const std::vector<std::shared_ptr<Round>> &rounds();
private:
std::vector<std::shared_ptr<Round>> rounds_;
// Server's current iteration number.
size_t iteration_num_;
};
} // namespace server
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_ITERATION_H_

View File

@ -0,0 +1,127 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ps/server/kernel/round/round_kernel.h"
#include <mutex>
#include <queue>
#include <chrono>
#include <thread>
#include <utility>
#include <string>
#include <vector>
namespace mindspore {
namespace ps {
namespace server {
namespace kernel {
RoundKernel::RoundKernel() : name_(""), current_count_(0), required_count_(0), error_reason_("") {
release_thread_ = std::thread([&]() {
while (true) {
std::unique_lock<std::mutex> release_lock(release_mtx_);
// Detect whether there's any data needs to be released every 100 milliseconds.
if (heap_data_to_release_.empty()) {
release_lock.unlock();
std::this_thread::sleep_for(std::chrono::milliseconds(100));
continue;
}
AddressPtr addr_ptr = heap_data_to_release_.front();
heap_data_to_release_.pop();
release_lock.unlock();
std::unique_lock<std::mutex> heap_data_lock(heap_data_mtx_);
if (heap_data_.count(addr_ptr) == 0) {
MS_LOG(ERROR) << "The data is not stored.";
continue;
}
// Manually release unique_ptr data.
heap_data_[addr_ptr].reset(nullptr);
heap_data_.erase(heap_data_.find(addr_ptr));
}
});
release_thread_.detach();
}
void RoundKernel::OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &message) { return; }
void RoundKernel::OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message) { return; }
void RoundKernel::StopTimer() {
if (stop_timer_cb_) {
stop_timer_cb_();
}
return;
}
void RoundKernel::FinishIteration() {
if (finish_iteration_cb_) {
finish_iteration_cb_();
}
return;
}
void RoundKernel::Release(AddressPtr addr_ptr) {
if (addr_ptr == nullptr) {
MS_LOG(ERROR) << "Data to be released is empty.";
return;
}
std::unique_lock<std::mutex> lock(release_mtx_);
heap_data_to_release_.push(addr_ptr);
return;
}
void RoundKernel::set_name(const std::string &name) { name_ = name; }
void RoundKernel::set_stop_timer_cb(StopTimerCb timer_stopper) { stop_timer_cb_ = timer_stopper; }
void RoundKernel::set_finish_iteration_cb(FinishIterCb finish_iteration_cb) {
finish_iteration_cb_ = finish_iteration_cb;
}
void RoundKernel::GenerateOutput(const std::vector<AddressPtr> &outputs, void *data, size_t len) {
if (data == nullptr) {
MS_LOG(ERROR) << "The data is nullptr.";
return;
}
if (outputs.empty()) {
MS_LOG(ERROR) << "Generating output failed. Outputs size is empty.";
return;
}
std::unique_ptr<unsigned char[]> output_data = std::make_unique<unsigned char[]>(len);
if (output_data == nullptr) {
MS_LOG(ERROR) << "Output data is nullptr.";
return;
}
size_t dst_size = len;
int ret = memcpy_s(output_data.get(), dst_size, data, len);
if (ret != 0) {
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
return;
}
outputs[0]->addr = output_data.get();
outputs[0]->size = len;
std::unique_lock<std::mutex> lock(heap_data_mtx_);
heap_data_.insert(std::make_pair(outputs[0], std::move(output_data)));
return;
}
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace mindspore

View File

@ -0,0 +1,130 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_ROUND_ROUND_KERNEL_H_
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_ROUND_ROUND_KERNEL_H_
#include <map>
#include <memory>
#include <string>
#include <vector>
#include <mutex>
#include <queue>
#include <utility>
#include <chrono>
#include <thread>
#include <unordered_map>
#include "backend/kernel_compiler/common_utils.h"
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "ps/server/common.h"
#include "ps/server/local_meta_store.h"
#include "ps/server/distributed_count_service.h"
#include "ps/server/distributed_metadata_store.h"
namespace mindspore {
namespace ps {
namespace server {
namespace kernel {
// RoundKernel contains the main logic of server handling messages from workers. One iteration has multiple round
// kernels to represent the process. They receive and parse messages from the server communication module. After
// handling these messages, round kernels allocate response data and send it back.
// For example, the main process of federated learning is:
// startFLJob round->updateModel round->getModel round.
class RoundKernel : virtual public CPUKernel {
public:
RoundKernel();
virtual ~RoundKernel() = default;
// RoundKernel doesn't use InitKernel method of base class CPUKernel to initialize. So implementation of this
// inherited method is empty.
void InitKernel(const CNodePtr &kernel_node) override {}
// Initialize RoundKernel with threshold_count which means that for every iteration, this round needs threshold_count
// messages.
virtual void InitKernel(size_t threshold_count) = 0;
// Launch the round kernel logic to handle the message passed by the communication module.
virtual bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) = 0;
// The callbacks when first message and last message for this round kernel is received.
// These methods is called by class DistributedCountService and triggered by leader server(Rank 0).
// virtual void OnFirstCountEvent(std::shared_ptr<core::MessageHandler> message);
// virtual void OnLastCnt(std::shared_ptr<core::MessageHandler> message);
// Some rounds could be stateful in a iteration. Reset method resets the status of this round.
virtual bool Reset() = 0;
// The counter event handlers for DistributedCountService.
virtual void OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &message);
virtual void OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message);
// Called when this round is finished. This round timer's Stop method will be called.
void StopTimer();
// Called after this iteration(including all rounds) is finished. All rounds' Reset method will
// be called.
void FinishIteration();
// Release the response data allocated inside the round kernel.
// Server framework must call this after the response data is sent back.
void Release(AddressPtr addr_ptr);
// Set round kernel name, which could be used in round kernel's methods.
void set_name(const std::string &name);
// Set callbacks to be called under certain triggered conditions.
void set_stop_timer_cb(StopTimerCb timer_stopper);
void set_finish_iteration_cb(FinishIterCb finish_iteration_cb);
protected:
// Generating response data of this round. The data is allocated on the heap to ensure it's not released before sent
// back to worker.
void GenerateOutput(const std::vector<AddressPtr> &outputs, void *data, size_t len);
// Round kernel's name.
std::string name_;
// The current received message count for this round in this iteration.
size_t current_count_;
// The required received message count for this round in one iteration.
size_t required_count_;
// The reason causes the error in this round kernel.
std::string error_reason_;
StopTimerCb stop_timer_cb_;
FinishIterCb finish_iteration_cb_;
// Members below are used for allocating and releasing response data on the heap.
// To ensure the performance, we use another thread to release data on the heap. So the operation on the data should
// be threadsafe.
std::thread release_thread_;
// Data needs to be released and its mutex;
std::mutex release_mtx_;
std::queue<AddressPtr> heap_data_to_release_;
std::mutex heap_data_mtx_;
std::unordered_map<AddressPtr, std::unique_ptr<unsigned char[]>> heap_data_;
};
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_ROUND_ROUND_KERNEL_H_

View File

@ -0,0 +1,44 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ps/server/kernel/round/round_kernel_factory.h"
namespace mindspore {
namespace ps {
namespace server {
namespace kernel {
RoundKernelFactory &RoundKernelFactory::GetInstance() {
static RoundKernelFactory instance;
return instance;
}
void RoundKernelFactory::Register(const std::string &name, RoundKernelCreator &&creator) {
name_to_creator_map_[name] = creator;
}
std::shared_ptr<RoundKernel> RoundKernelFactory::Create(const std::string &name) {
if (name_to_creator_map_.count(name) == 0) {
MS_LOG(ERROR) << "Round kernel " << name << " is not registered.";
return nullptr;
}
auto kernel = name_to_creator_map_[name]();
kernel->set_name(name);
return kernel;
}
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace mindspore

View File

@ -0,0 +1,62 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_ROUND_ROUND_KERNEL_FACTORY_H_
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_ROUND_ROUND_KERNEL_FACTORY_H_
#include <memory>
#include <string>
#include <utility>
#include <unordered_map>
#include "ps/server/common.h"
#include "ps/server/kernel/round/round_kernel.h"
namespace mindspore {
namespace ps {
namespace server {
namespace kernel {
using RoundKernelCreator = std::function<std::shared_ptr<RoundKernel>()>;
// Kernel factory of round kernels.
class RoundKernelFactory {
public:
static RoundKernelFactory &GetInstance();
void Register(const std::string &name, RoundKernelCreator &&creator);
std::shared_ptr<RoundKernel> Create(const std::string &name);
private:
RoundKernelFactory() = default;
~RoundKernelFactory() = default;
RoundKernelFactory(const RoundKernelFactory &) = delete;
RoundKernelFactory &operator=(const RoundKernelFactory &) = delete;
std::unordered_map<std::string, RoundKernelCreator> name_to_creator_map_;
};
class RoundKernelRegister {
public:
RoundKernelRegister(const std::string &name, RoundKernelCreator &&creator) {
RoundKernelFactory::GetInstance().Register(name, std::move(creator));
}
};
#define REG_ROUND_KERNEL(NAME, CLASS) \
static_assert(std::is_base_of<RoundKernel, CLASS>::value, " must be base of RoundKernel"); \
static const RoundKernelRegister g_##NAME##_round_kernel_reg(#NAME, []() { return std::make_shared<CLASS>(); });
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_ROUND_ROUND_KERNEL_FACTORY_H_

View File

@ -0,0 +1,192 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ps/server/kernel/round/start_fl_job_kernel.h"
#include <map>
#include <memory>
#include <string>
#include <vector>
namespace mindspore {
namespace ps {
namespace server {
namespace kernel {
void StartFLJobKernel::InitKernel(size_t) {
if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) {
iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
}
executor_ = &Executor::GetInstance();
MS_EXCEPTION_IF_NULL(executor_);
if (!executor_->initialized()) {
MS_LOG(EXCEPTION) << "Executor must be initialized in server pipeline.";
return;
}
PBMetadata devices_metas;
DistributedMetadataStore::GetInstance().RegisterMetadata(kCtxDeviceMetas, devices_metas);
return;
}
bool StartFLJobKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
MS_LOG(INFO) << "Launching StartFLJobKernel kernel.";
if (inputs.size() != 1 || outputs.size() != 1) {
MS_LOG(ERROR) << "inputs or outputs size is invalid.";
return false;
}
void *req_data = inputs[0]->addr;
const std::shared_ptr<FBBuilder> &fbb = std::make_shared<FBBuilder>();
if (fbb == nullptr || req_data == nullptr) {
MS_LOG(ERROR) << "FBBuilder builder or req_data is nullptr.";
return false;
}
if (ReachThresholdForStartFLJob(fbb)) {
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return false;
}
const schema::RequestFLJob *start_fl_job_req = flatbuffers::GetRoot<schema::RequestFLJob>(req_data);
DeviceMeta device_meta = CreateDeviceMetadata(start_fl_job_req);
if (!ReadyForStartFLJob(fbb, device_meta)) {
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return false;
}
// If calling ReportCount before ReadyForStartFLJob, the result will be inconsistent if the device is not selected.
if (!CountForStartFLJob(fbb, start_fl_job_req)) {
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return false;
}
StartFLJob(fbb, device_meta);
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
bool StartFLJobKernel::Reset() {
MS_LOG(INFO) << "Starting fl job kernel reset!";
StopTimer();
DistributedCountService::GetInstance().ResetCounter(name_);
DistributedMetadataStore::GetInstance().ResetMetadata(kCtxDeviceMetas);
return true;
}
bool StartFLJobKernel::ReachThresholdForStartFLJob(const std::shared_ptr<FBBuilder> &fbb) {
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
std::string reason = "Current amount for startFLJob has reached the threshold. Please startFLJob later.";
BuildStartFLJobRsp(fbb, schema::ResponseCode_OutOfTime, reason, false,
std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_));
MS_LOG(ERROR) << reason;
return true;
}
return false;
}
DeviceMeta StartFLJobKernel::CreateDeviceMetadata(const schema::RequestFLJob *start_fl_job_req) {
std::string fl_name = start_fl_job_req->fl_name()->str();
std::string fl_id = start_fl_job_req->fl_id()->str();
int data_size = start_fl_job_req->data_size();
MS_LOG(INFO) << "DeviceMeta fl_name:" << fl_name << ", fl_id:" << fl_id << ", data_size:" << data_size;
DeviceMeta device_meta;
device_meta.set_fl_name(fl_name);
device_meta.set_fl_id(fl_id);
device_meta.set_data_size(data_size);
return device_meta;
}
bool StartFLJobKernel::ReadyForStartFLJob(const std::shared_ptr<FBBuilder> &fbb, const DeviceMeta &device_meta) {
bool ret = true;
std::string reason = "";
if (device_meta.data_size() < 1) {
reason = "FL job data size is not enough.";
ret = false;
}
if (!ret) {
BuildStartFLJobRsp(fbb, schema::ResponseCode_NotSelected, reason, false,
std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_));
MS_LOG(ERROR) << reason;
}
return ret;
}
bool StartFLJobKernel::CountForStartFLJob(const std::shared_ptr<FBBuilder> &fbb,
const schema::RequestFLJob *start_fl_job_req) {
if (!DistributedCountService::GetInstance().Count(name_, start_fl_job_req->fl_id()->str())) {
std::string reason = "startFLJob counting failed.";
BuildStartFLJobRsp(fbb, schema::ResponseCode_OutOfTime, reason, false,
std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_));
MS_LOG(ERROR) << reason;
return false;
}
return true;
}
void StartFLJobKernel::StartFLJob(const std::shared_ptr<FBBuilder> &fbb, const DeviceMeta &device_meta) {
PBMetadata metadata;
*metadata.mutable_device_meta() = device_meta;
DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxDeviceMetas, metadata);
std::map<std::string, AddressPtr> feature_maps = executor_->GetModel();
BuildStartFLJobRsp(fbb, schema::ResponseCode_SUCCEED, "success", true,
std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_), feature_maps);
return;
}
void StartFLJobKernel::BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode,
const std::string &reason, const bool is_selected,
const std::string &next_req_time,
std::map<std::string, AddressPtr> feature_maps) {
auto fbs_reason = fbb->CreateString(reason);
auto fbs_next_req_time = fbb->CreateString(next_req_time);
auto fbs_fl_name = fbb->CreateString(PSContext::instance()->fl_name());
schema::FLPlanBuilder fl_plan_builder(*(fbb.get()));
fl_plan_builder.add_fl_name(fbs_fl_name);
fl_plan_builder.add_iterations(PSContext::instance()->fl_iteration_num());
fl_plan_builder.add_epochs(PSContext::instance()->client_epoch_num());
fl_plan_builder.add_mini_batch(PSContext::instance()->client_batch_size());
auto fbs_fl_plan = fl_plan_builder.Finish();
std::vector<flatbuffers::Offset<schema::FeatureMap>> fbs_feature_maps;
for (auto feature_map : feature_maps) {
auto fbs_weight_fullname = fbb->CreateString(feature_map.first);
auto fbs_weight_data =
fbb->CreateVector(reinterpret_cast<float *>(feature_map.second->addr), feature_map.second->size / sizeof(float));
auto fbs_feature_map = schema::CreateFeatureMap(*(fbb.get()), fbs_weight_fullname, fbs_weight_data);
fbs_feature_maps.push_back(fbs_feature_map);
}
auto fbs_feature_maps_vector = fbb->CreateVector(fbs_feature_maps);
schema::ResponseFLJobBuilder rsp_fl_job_builder(*(fbb.get()));
rsp_fl_job_builder.add_retcode(retcode);
rsp_fl_job_builder.add_reason(fbs_reason);
rsp_fl_job_builder.add_iteration(LocalMetaStore::GetInstance().curr_iter_num());
rsp_fl_job_builder.add_is_selected(is_selected);
rsp_fl_job_builder.add_next_req_time(fbs_next_req_time);
rsp_fl_job_builder.add_fl_plan_config(fbs_fl_plan);
rsp_fl_job_builder.add_feature_map(fbs_feature_maps_vector);
auto rsp_fl_job = rsp_fl_job_builder.Finish();
fbb->Finish(rsp_fl_job);
return;
}
REG_ROUND_KERNEL(startFLJob, StartFLJobKernel)
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace mindspore

View File

@ -0,0 +1,74 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_START_FL_JOB_KERNEL_H_
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_START_FL_JOB_KERNEL_H_
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "ps/server/common.h"
#include "ps/server/executor.h"
#include "ps/server/kernel/round/round_kernel.h"
#include "ps/server/kernel/round/round_kernel_factory.h"
namespace mindspore {
namespace ps {
namespace server {
namespace kernel {
class StartFLJobKernel : public RoundKernel {
public:
StartFLJobKernel() = default;
~StartFLJobKernel() override = default;
void InitKernel(size_t threshold_count) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
bool Reset() override;
private:
// Returns whether the startFLJob count of this iteration has reached the threshold.
bool ReachThresholdForStartFLJob(const std::shared_ptr<FBBuilder> &fbb);
// The metadata of device will be stored and queried in updateModel round.
DeviceMeta CreateDeviceMetadata(const schema::RequestFLJob *start_fl_job_req);
// Returns whether the request is valid for startFLJob.For now, the condition is simple. We will add more conditions
// to device in later versions.
bool ReadyForStartFLJob(const std::shared_ptr<FBBuilder> &fbb, const DeviceMeta &device_meta);
// Distributed count service counts for startFLJob.
bool CountForStartFLJob(const std::shared_ptr<FBBuilder> &fbb, const schema::RequestFLJob *start_fl_job_req);
void StartFLJob(const std::shared_ptr<FBBuilder> &fbb, const DeviceMeta &device_meta);
// Build response for startFLJob round no matter success or failure.
void BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode,
const std::string &reason, const bool is_selected, const std::string &next_req_time,
std::map<std::string, AddressPtr> feature_maps = {});
// The executor is for getting the initial model for startFLJob request.
Executor *executor_;
// The time window of one iteration.
size_t iteration_time_window_;
};
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_START_FL_JOB_KERNEL_H_

View File

@ -14,30 +14,29 @@
* limitations under the License. * limitations under the License.
*/ */
#include "ps/server/local_meta_storage.h" #include "ps/server/local_meta_store.h"
#include <string>
namespace mindspore { namespace mindspore {
namespace ps { namespace ps {
namespace server { namespace server {
void LocalMetaStorage::remove_value(const std::string &name) { void LocalMetaStore::remove_value(const std::string &name) {
std::unique_lock<std::mutex> lock(mtx_); std::unique_lock<std::mutex> lock(mtx_);
if (key_to_meta_.count(name) != 0) { if (key_to_meta_.count(name) != 0) {
key_to_meta_.erase(key_to_meta_.find(name)); key_to_meta_.erase(key_to_meta_.find(name));
} }
} }
bool LocalMetaStorage::has_value(const std::string &name) { bool LocalMetaStore::has_value(const std::string &name) {
std::unique_lock<std::mutex> lock(mtx_); std::unique_lock<std::mutex> lock(mtx_);
return key_to_meta_.count(name) != 0; return key_to_meta_.count(name) != 0;
} }
void LocalMetaStorage::set_curr_iter_num(size_t num) { void LocalMetaStore::set_curr_iter_num(size_t num) {
std::unique_lock<std::mutex> lock(mtx_); std::unique_lock<std::mutex> lock(mtx_);
curr_iter_num_ = num; curr_iter_num_ = num;
} }
const size_t LocalMetaStorage::curr_iter_num() { const size_t LocalMetaStore::curr_iter_num() {
std::unique_lock<std::mutex> lock(mtx_); std::unique_lock<std::mutex> lock(mtx_);
return curr_iter_num_; return curr_iter_num_;
} }

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_PS_SERVER_LOCAL_META_STORAGE_H_ #ifndef MINDSPORE_CCSRC_PS_SERVER_LOCAL_META_STORE_H_
#define MINDSPORE_CCSRC_PS_SERVER_LOCAL_META_STORAGE_H_ #define MINDSPORE_CCSRC_PS_SERVER_LOCAL_META_STORE_H_
#include <any> #include <any>
#include <mutex> #include <mutex>
@ -26,13 +26,13 @@
namespace mindspore { namespace mindspore {
namespace ps { namespace ps {
namespace server { namespace server {
// LocalMetaStorage class is used for metadata storage of this server process. // LocalMetaStore class is used for metadata storage of this server process.
// For example, the current iteration number, time windows for round kernels, etc. // For example, the current iteration number, time windows for round kernels, etc.
// LocalMetaStorage is threadsafe. // LocalMetaStore is threadsafe.
class LocalMetaStorage { class LocalMetaStore {
public: public:
static LocalMetaStorage &GetInstance() { static LocalMetaStore &GetInstance() {
static LocalMetaStorage instance; static LocalMetaStore instance;
return instance; return instance;
} }
@ -43,7 +43,7 @@ class LocalMetaStorage {
} }
template <typename T> template <typename T>
const T &value(const std::string &name) { T value(const std::string &name) {
std::unique_lock<std::mutex> lock(mtx_); std::unique_lock<std::mutex> lock(mtx_);
try { try {
T value = std::any_cast<T>(key_to_meta_[name]); T value = std::any_cast<T>(key_to_meta_[name]);
@ -71,10 +71,10 @@ class LocalMetaStorage {
const size_t curr_iter_num(); const size_t curr_iter_num();
private: private:
LocalMetaStorage() = default; LocalMetaStore() = default;
~LocalMetaStorage() = default; ~LocalMetaStore() = default;
LocalMetaStorage(const LocalMetaStorage &) = delete; LocalMetaStore(const LocalMetaStore &) = delete;
LocalMetaStorage &operator=(const LocalMetaStorage &) = delete; LocalMetaStore &operator=(const LocalMetaStore &) = delete;
// key_to_meta_ stores metadata with key-value format. // key_to_meta_ stores metadata with key-value format.
std::unordered_map<std::string, std::any> key_to_meta_; std::unordered_map<std::string, std::any> key_to_meta_;
@ -85,4 +85,4 @@ class LocalMetaStorage {
} // namespace server } // namespace server
} // namespace ps } // namespace ps
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_LOCAL_META_STORAGE_H_ #endif // MINDSPORE_CCSRC_PS_SERVER_LOCAL_META_STORE_H_

View File

@ -0,0 +1,144 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ps/server/model_store.h"
#include <map>
#include <string>
#include <memory>
#include "ps/server/executor.h"
namespace mindspore {
namespace ps {
namespace server {
void ModelStore::Init(uint32_t max_count) {
if (!Executor::GetInstance().initialized()) {
MS_LOG(EXCEPTION) << "Server's executor must be initialized before model storage.";
return;
}
max_model_count_ = max_count;
iteration_to_model_[kInitIterationNum] = AssignNewModelMemory();
model_size_ = ComputeModelSize();
}
bool ModelStore::StoreModelByIterNum(size_t iteration, const std::map<std::string, AddressPtr> &new_model) {
if (iteration_to_model_.count(iteration) != 0) {
MS_LOG(WARNING) << "Model for iteration " << iteration << " is already stored";
return false;
}
if (new_model.empty()) {
MS_LOG(ERROR) << "Model feature map is empty.";
return false;
}
std::shared_ptr<MemoryRegister> memory_register;
if (iteration_to_model_.size() < max_model_count_) {
// If iteration_to_model_.size() is not max_model_count_, need to assign new memory for the model.
memory_register = AssignNewModelMemory();
if (memory_register == nullptr) {
MS_LOG(ERROR) << "Memory for the new model is nullptr.";
return false;
}
iteration_to_model_[iteration] = memory_register;
} else {
// If iteration_to_model_ size is already max_model_count_, we need to replace earliest model with the newest model.
memory_register = iteration_to_model_.begin()->second;
if (memory_register == nullptr) {
MS_LOG(ERROR) << "Earliest model is nullptr.";
return false;
}
iteration_to_model_.erase(iteration_to_model_.begin());
}
// Copy new model data to the the stored model.
auto &stored_model = memory_register->addresses();
for (const auto &weight : new_model) {
const std::string &weight_name = weight.first;
if (stored_model.count(weight_name) != 0) {
MS_LOG(ERROR) << "The stored model has no weight " << weight_name;
continue;
}
void *dst_addr = stored_model[weight_name]->addr;
size_t dst_size = stored_model[weight_name]->size;
void *src_addr = weight.second->addr;
size_t src_size = weight.second->size;
int ret = memcpy_s(dst_addr, dst_size, src_addr, src_size);
if (ret != 0) {
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
return false;
}
}
iteration_to_model_[iteration] = memory_register;
return true;
}
std::map<std::string, AddressPtr> ModelStore::GetModelByIterNum(size_t iteration) {
std::map<std::string, AddressPtr> model = {};
if (iteration_to_model_.count(iteration) == 0) {
MS_LOG(ERROR) << "Model for iteration " << iteration << " is not stored.";
return model;
}
model = iteration_to_model_[iteration]->addresses();
return model;
}
const std::map<size_t, std::shared_ptr<MemoryRegister>> &ModelStore::iteration_to_model() const {
return iteration_to_model_;
}
size_t ModelStore::model_size() const { return model_size_; }
std::shared_ptr<MemoryRegister> ModelStore::AssignNewModelMemory() {
std::map<std::string, AddressPtr> model = Executor::GetInstance().GetModel();
if (model.empty()) {
MS_LOG(EXCEPTION) << "Model feature map is empty.";
return nullptr;
}
// Assign new memory for the model.
std::shared_ptr<MemoryRegister> memory_register = std::make_shared<MemoryRegister>();
for (const auto &weight : model) {
const std::string weight_name = weight.first;
size_t weight_size = weight.second->size;
auto weight_data = std::make_unique<char[]>(weight_size);
if (weight_data == nullptr) {
MS_LOG(EXCEPTION) << "Assign memory for weight failed.";
return nullptr;
}
memory_register->RegisterArray(weight_name, &weight_data, weight_size);
}
return memory_register;
}
size_t ModelStore::ComputeModelSize() {
if (iteration_to_model_.empty()) {
MS_LOG(EXCEPTION) << "Calculating model size failed: model for iteration 0 is not stored yet. ";
return 0;
}
const auto &model = iteration_to_model_[kInitIterationNum];
MS_EXCEPTION_IF_NULL(model);
size_t model_size = std::accumulate(model->addresses().begin(), model->addresses().end(), static_cast<size_t>(0),
[](size_t s, const auto &weight) { return s + weight.second->size; });
MS_LOG(INFO) << "Model size in byte is " << model_size;
return model_size;
}
} // namespace server
} // namespace ps
} // namespace mindspore

View File

@ -0,0 +1,78 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_SERVER_MODEL_STORE_H_
#define MINDSPORE_CCSRC_PS_SERVER_MODEL_STORE_H_
#include <map>
#include <memory>
#include <string>
#include "ps/server/common.h"
#include "ps/server/memory_register.h"
#include "ps/server/executor.h"
namespace mindspore {
namespace ps {
namespace server {
// The initial iteration number is 0 in server.
constexpr size_t kInitIterationNum = 0;
// Server framework use ModelStore to store and query models.
// ModelStore stores multiple models because worker could get models of the previous iterations.
class ModelStore {
public:
static ModelStore &GetInstance() {
static ModelStore instance;
return instance;
}
// Initialize ModelStore with max count of models need to be stored.
void Init(uint32_t max_count = 3);
// Store the model of the given iteration. The model is acquired from Executor. If the current model count is already
// max_model_count_, the earliest model will be replaced.
bool StoreModelByIterNum(size_t iteration, const std::map<std::string, AddressPtr> &model);
// Get model of the given iteration.
std::map<std::string, AddressPtr> GetModelByIterNum(size_t iteration);
// Returns all models stored in ModelStore.
const std::map<size_t, std::shared_ptr<MemoryRegister>> &iteration_to_model() const;
// Returns the model size, which could be calculated at the initializing phase.
size_t model_size() const;
private:
ModelStore() : max_model_count_(0), model_size_(0), iteration_to_model_({}) {}
~ModelStore() = default;
ModelStore(const ModelStore &) = delete;
ModelStore &operator=(const ModelStore &) = delete;
// To store multiple models, new memory must assigned. The max memory size assigned for models is max_model_count_ *
// model_size_.
std::shared_ptr<MemoryRegister> AssignNewModelMemory();
// Calculate the model size. This method should be called after iteration_to_model_ is initialized.
size_t ComputeModelSize();
size_t max_model_count_;
size_t model_size_;
std::map<size_t, std::shared_ptr<MemoryRegister>> iteration_to_model_;
};
} // namespace server
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_MODEL_STORE_H_

View File

@ -25,15 +25,15 @@
namespace mindspore { namespace mindspore {
namespace ps { namespace ps {
namespace server { namespace server {
bool ParameterAggregator::Init(const CNodePtr &cnode, size_t required_count) { bool ParameterAggregator::Init(const CNodePtr &cnode, size_t threshold_count) {
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
memory_register_ = std::make_shared<MemoryRegister>(); memory_register_ = std::make_shared<MemoryRegister>();
MS_EXCEPTION_IF_NULL(memory_register_); MS_EXCEPTION_IF_NULL(memory_register_);
required_push_count_ = required_count; required_push_count_ = threshold_count;
// The required_pull_count_ is the count for Pull, which should be the same as required_push_count_. // The required_pull_count_ is the count for Pull, which should be the same as required_push_count_.
// required_pull_count_ normally used in parameter server training mode. // required_pull_count_ normally used in parameter server training mode.
required_pull_count_ = required_count; required_pull_count_ = threshold_count;
MS_LOG(DEBUG) << "Start initializing kernels for " << AnfAlgo::GetCNodeName(cnode); MS_LOG(DEBUG) << "Start initializing kernels for " << AnfAlgo::GetCNodeName(cnode);
InitAggregationKernels(cnode); InitAggregationKernels(cnode);

View File

@ -61,8 +61,8 @@ class ParameterAggregator {
~ParameterAggregator() = default; ~ParameterAggregator() = default;
// Initialize ParameterAggregator with a cnode. This cnode is normally a optimizer kernel for now. // Initialize ParameterAggregator with a cnode. This cnode is normally a optimizer kernel for now.
// The parameter required_count helps ParameterAggregator to judge the current status if it's stateful. // The parameter threshold_count helps ParameterAggregator to judge the current status if it's stateful.
bool Init(const CNodePtr &cnode, size_t required_count = 0); bool Init(const CNodePtr &cnode, size_t threshold_count = 0);
// Update old data stored in ParameterAggregator with new data. // Update old data stored in ParameterAggregator with new data.
// The data could have many meanings: weights, gradients, learning_rate, momentum, etc. // The data could have many meanings: weights, gradients, learning_rate, momentum, etc.

View File

@ -0,0 +1,139 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ps/server/round.h"
#include <memory>
#include <string>
namespace mindspore {
namespace ps {
namespace server {
Round::Round(const std::string &name, bool check_timeout, size_t time_window, bool check_count, size_t threshold_count)
: name_(name),
check_timeout_(check_timeout),
time_window_(time_window),
check_count_(check_count),
threshold_count_(threshold_count) {}
void Round::Initialize(const std::shared_ptr<core::CommunicatorBase> &communicator, TimeOutCb timeout_cb,
FinishIterCb finish_iteration_cb) {
MS_EXCEPTION_IF_NULL(communicator);
communicator_ = communicator;
// Register callback for round kernel.
communicator_->RegisterMsgCallBack(
name_, [&](std::shared_ptr<core::MessageHandler> message) { LaunchRoundKernel(message); });
// Callback when the iteration is finished.
finish_iteration_cb_ = [this, finish_iteration_cb](void) -> void {
MS_LOG(INFO) << "Round " << name_ << " finished! Proceed to next iteration.";
finish_iteration_cb();
};
// Callback for finalizing the server. This can only be called once.
finalize_cb_ = [&](void) -> void { communicator_->Stop(); };
if (check_timeout_) {
iter_timer_ = std::make_shared<IterationTimer>();
// 1.Set the timeout callback for the timer.
iter_timer_->SetTimeOutCallBack([this, timeout_cb](void) -> void {
MS_LOG(INFO) << "Round " << name_ << " timeout! Proceed to next iteration.";
timeout_cb();
});
// 2.Stopping timer callback which will be set to the round kernel.
stop_timer_cb_ = [&](void) -> void {
MS_LOG(INFO) << "Round " << name_ << " kernel stops its timer.";
iter_timer_->Stop();
};
}
// Set counter event callbacks for this round if the round kernel is stateful.
if (check_count_) {
auto first_count_handler = std::bind(&Round::OnFirstCountEvent, this, std::placeholders::_1);
auto last_count_handler = std::bind(&Round::OnLastCountEvent, this, std::placeholders::_1);
DistributedCountService::GetInstance().RegisterCounter(name_, threshold_count_,
{first_count_handler, last_count_handler});
}
}
void Round::BindRoundKernel(const std::shared_ptr<kernel::RoundKernel> &kernel) {
MS_EXCEPTION_IF_NULL(kernel);
kernel_ = kernel;
kernel_->set_stop_timer_cb(stop_timer_cb_);
kernel_->set_finish_iteration_cb(finish_iteration_cb_);
return;
}
void Round::LaunchRoundKernel(const std::shared_ptr<core::MessageHandler> &message) {
if (message == nullptr) {
MS_LOG(ERROR) << "Message is nullptr.";
return;
}
AddressPtr input = std::make_shared<Address>();
AddressPtr output = std::make_shared<Address>();
input->addr = message->data();
input->size = message->len();
bool ret = kernel_->Launch({input}, {}, {output});
if (output->size == 0) {
std::string reason = "The output of the round " + name_ + " is empty.";
MS_LOG(WARNING) << reason;
communicator_->SendResponse(reason.c_str(), reason.size(), message);
return;
}
// Must send response back no matter what value Launch method returns.
if (!ret) {
MS_LOG(WARNING) << "Launching round kernel of round " << name_ << " failed.";
}
communicator_->SendResponse(output->addr, output->size, message);
kernel_->Release(output);
return;
}
void Round::Reset() { kernel_->Reset(); }
const std::string &Round::name() const { return name_; }
size_t Round::threshold_count() const { return threshold_count_; }
size_t Round::time_window() const { return time_window_; }
void Round::OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &) {
MS_LOG(INFO) << "Round " << name_ << " first count event is triggered.";
// The timer starts only after the first count event is triggered by DistributedCountService.
if (check_timeout_) {
iter_timer_->Start(std::chrono::milliseconds(time_window_));
}
return;
}
void Round::OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message) {
MS_LOG(INFO) << "Round " << name_ << " last count event is triggered.";
// Same as the first count event, the timer must be stopped by DistributedCountService.
if (check_timeout_) {
iter_timer_->Stop();
}
// Some kernels override the OnLastCountEvent method.
kernel_->OnLastCountEvent(message);
return;
}
} // namespace server
} // namespace ps
} // namespace mindspore

View File

@ -0,0 +1,95 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_SERVER_ROUND_H_
#define MINDSPORE_CCSRC_PS_SERVER_ROUND_H_
#include <memory>
#include <string>
#include "ps/core/communicator/communicator_base.h"
#include "ps/server/common.h"
#include "ps/server/iteration_timer.h"
#include "ps/server/distributed_count_service.h"
#include "ps/server/kernel/round/round_kernel.h"
namespace mindspore {
namespace ps {
namespace server {
// Round helps server to handle network round messages and launch round kernels. One iteration in server consists of
// multiple rounds like startFLJob, updateModel, Push, Pull, etc. Some round kernels may be stateful because of counting
// and timing. So Round helps register counter and timer so that the round kernels only need to focus on the logic.
class Round {
public:
explicit Round(const std::string &name, bool check_timeout = true, size_t time_window = 3000,
bool check_count = false, size_t threshold_count = 8);
~Round() = default;
void Initialize(const std::shared_ptr<core::CommunicatorBase> &communicator, TimeOutCb timeout_cb,
FinishIterCb finish_iteration_cb);
// Bind a round kernel to this Round. This method should be called after Initialize.
void BindRoundKernel(const std::shared_ptr<kernel::RoundKernel> &kernel);
// This method is the callback which will be set to the communicator and called after the corresponding round message
// is sent to the server.
void LaunchRoundKernel(const std::shared_ptr<core::MessageHandler> &message);
// Round needs to be reset after each iteration is finished or its timer expires.
void Reset();
const std::string &name() const;
size_t threshold_count() const;
size_t time_window() const;
private:
// The callbacks which will be set to DistributedCounterService.
void OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &message);
void OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message);
std::string name_;
// Whether this round needs to use timer. Most rounds in federated learning with mobile devices scenario need to set
// check_timeout_ to true.
bool check_timeout_;
// The time window duration for this round when check_timeout_ is set to true.
size_t time_window_;
// If check_count_ is true, it means the round has to do counting for every round message and the first/last count
// event will be triggered.
bool check_count_;
// The threshold count for this round when check_count_ is set to true. The logic of this round has to check whether
// the round message count has reached threshold_count_.
size_t threshold_count_;
std::shared_ptr<core::CommunicatorBase> communicator_;
// The round kernel for this Round.
std::shared_ptr<kernel::RoundKernel> kernel_;
// Some rounds may need timer to eliminate the long tail effect.
std::shared_ptr<IterationTimer> iter_timer_;
// The callbacks which will be set to the round kernel.
StopTimerCb stop_timer_cb_;
FinishIterCb finish_iteration_cb_;
FinalizeCb finalize_cb_;
};
} // namespace server
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_ROUND_H_

View File

@ -31,7 +31,7 @@
#include "utils/utils.h" #include "utils/utils.h"
#include "frontend/parallel/context.h" #include "frontend/parallel/context.h"
#include "debug/env_config_parser.h" #include "debug/env_config_parser.h"
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) #if (ENABLE_CPU && !_WIN32)
#include "ps/ps_cache/ps_cache_manager.h" #include "ps/ps_cache/ps_cache_manager.h"
#endif #endif
@ -307,7 +307,7 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) {
} }
need_alloc_nodes.push_back(item); need_alloc_nodes.push_back(item);
} }
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) #if (ENABLE_CPU && !_WIN32)
bool ps_cache_check = false; bool ps_cache_check = false;
#endif #endif
for (auto &item : need_alloc_nodes) { for (auto &item : need_alloc_nodes) {
@ -320,7 +320,7 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) {
continue; continue;
} }
DeviceAddressPtr device_address = nullptr; DeviceAddressPtr device_address = nullptr;
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) #if (ENABLE_CPU && !_WIN32)
const std::string &param_name = item->fullname_with_scope(); const std::string &param_name = item->fullname_with_scope();
if (ps::ps_cache_instance.IsHashTable(param_name)) { if (ps::ps_cache_instance.IsHashTable(param_name)) {
MS_LOG(INFO) << "Parameter(" << param_name << ")" MS_LOG(INFO) << "Parameter(" << param_name << ")"
@ -1038,7 +1038,7 @@ DeviceAddressPtr KernelRuntime::AssignSingleOpLaunchMemory(size_t size, const st
return device_address; return device_address;
} }
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) #if (ENABLE_CPU && !_WIN32)
void KernelRuntime::GetFirstPSEmbeddingCache(const session::KernelGraph *graph, void KernelRuntime::GetFirstPSEmbeddingCache(const session::KernelGraph *graph,
AnfNodePtr *const first_cache_input_index, AnfNodePtr *const first_cache_input_index,
size_t *const first_cache_size) { size_t *const first_cache_size) {

View File

@ -142,7 +142,7 @@ class KernelRuntime {
void RunOpAssignOutputNodeMemory(const ValuePtr &pre_output_value, session::KernelGraph *graph); void RunOpAssignOutputNodeMemory(const ValuePtr &pre_output_value, session::KernelGraph *graph);
void AssignValueNodeTensor(const ValueNodePtr &value_node, const ValuePtr &node_value, size_t output_idx); void AssignValueNodeTensor(const ValueNodePtr &value_node, const ValuePtr &node_value, size_t output_idx);
DeviceAddressPtr PreAssignCNodeMemory(const AnfNodePtr &anf_node, size_t index); DeviceAddressPtr PreAssignCNodeMemory(const AnfNodePtr &anf_node, size_t index);
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) #if (ENABLE_CPU && !_WIN32)
void GetFirstPSEmbeddingCache(const session::KernelGraph *graph, AnfNodePtr *const first_cache_input_index, void GetFirstPSEmbeddingCache(const session::KernelGraph *graph, AnfNodePtr *const first_cache_input_index,
size_t *const first_cache_size); size_t *const first_cache_size);
void CheckIfSupportPSEmbeddingCache(const session::KernelGraph *graph); void CheckIfSupportPSEmbeddingCache(const session::KernelGraph *graph);

View File

@ -16,14 +16,14 @@
#include "runtime/device/kernel_runtime_manager.h" #include "runtime/device/kernel_runtime_manager.h"
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) #if (ENABLE_CPU && !_WIN32)
#include "ps/ps_cache/ps_cache_manager.h" #include "ps/ps_cache/ps_cache_manager.h"
#endif #endif
namespace mindspore { namespace mindspore {
namespace device { namespace device {
void KernelRuntimeManager::ClearRuntimeResource() { void KernelRuntimeManager::ClearRuntimeResource() {
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) #if (ENABLE_CPU && !_WIN32)
if (ps::PSContext::instance()->is_worker() && ps::PsDataPrefetch::GetInstance().cache_enable()) { if (ps::PSContext::instance()->is_worker() && ps::PsDataPrefetch::GetInstance().cache_enable()) {
ps::ps_cache_instance.SyncEmbeddingTable(); ps::ps_cache_instance.SyncEmbeddingTable();
} }
@ -125,7 +125,7 @@ void KernelRuntimeManager::ReleaseKernelRuntime(const std::string &device_name,
if (runtime == nullptr) { if (runtime == nullptr) {
return; return;
} }
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) #if (ENABLE_CPU && !_WIN32)
if (ps::PSContext::instance()->is_worker() && ps::PsDataPrefetch::GetInstance().cache_enable()) { if (ps::PSContext::instance()->is_worker() && ps::PsDataPrefetch::GetInstance().cache_enable()) {
ps::ps_cache_instance.SyncEmbeddingTable(); ps::ps_cache_instance.SyncEmbeddingTable();
} }

123
mindspore/schema/cipher.fbs Normal file
View File

@ -0,0 +1,123 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
namespace mindspore.schema;
table CipherPublicParams {
t:int;
p:[ubyte];
g:int;
prime:[ubyte];
dp_eps:float;
dp_delta:float;
dp_norm_clip:float;
encrypt_type:int;
}
table ClientPublicKeys {
fl_id:string;
c_pk:[ubyte];
s_pk: [ubyte];
}
table ClientShare {
fl_id:string;
share:[ubyte];
index:int;
}
table RequestExchangeKeys{
fl_id:string;
c_pk:[ubyte];
s_pk:[ubyte];
iteration:int;
timestamp:string;
}
table ResponseExchangeKeys{
retcode:int;
reason:string;
next_req_time:string;
iteration:int;
}
table GetExchangeKeys{
fl_id:string;
iteration:int;
timestamp:string;
}
table ReturnExchangeKeys{
retcode:int;
iteration:int;
remote_publickeys:[ClientPublicKeys];
next_req_time:string;
}
table RequestShareSecrets{
fl_id:string;
encrypted_shares:[ClientShare];
iteration:int;
timestamp:string;
}
table ResponseShareSecrets{
retcode:int;
reason:string;
next_req_time:string;
iteration:int;
}
table GetShareSecrets{
fl_id:string;
iteration:int;
timestamp:string;
}
table ReturnShareSecrets{
retcode:int;
iteration:int;
encrypted_shares: [ClientShare];
next_req_time:string;
}
table GetClientList{
fl_id:string;
iteration:int;
timestamp:string;
}
table ReturnClientList{
retcode:int;
reason:string;
clients:[string];
iteration:int;
next_req_time:string;
}
table SendReconstructSecret{
fl_id:string;
reconstruct_secret_shares:[ClientShare];
iteration:int;
timestamp:string;
}
table ReconstructSecret{
retcode:int;
reason:string;
iteration:int;
next_req_time:string;
}

159
mindspore/schema/fl_job.fbs Normal file
View File

@ -0,0 +1,159 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
include "cipher.fbs";
namespace mindspore.schema;
file_identifier "FLJ0";
file_extension "fl";
enum ResponseCode: int {
SUCCEED=200,
SucNotReady=201,
RepeatRequest=202,
SucNotMatch=204,
OutOfTime=300,
NotSelected=301,
RequestError=400,
SystemError=500
}
enum AggregationType:byte {FedAvg=0, FedAdam = 1, FedAdagrag=2, FedMeta=3, qffl=4}
enum Metrics:byte {accuracy = 0, precision = 1, recall = 2, AUC = 3,f1=4, fbeta=5}
enum EarlyStopType:byte {loss_diff = 0, loss_abs = 1, weight_diff = 2}
table Aggregation {
type:AggregationType;
weights:[float];
}
table EarlyStop {
early_stop_type:EarlyStopType;
weight:float;
rounds:int;
}
table FeatureMap{
weight_fullname:string;
data:[float];
}
table RequestFLJob{
fl_name:string;
fl_id:string;
iteration:int;
data_size:int;
timestamp:string;
}
table ResponseFLJob {
retcode:int;
reason:string;
iteration:int;
is_selected:bool = false;
next_req_time:string;
fl_plan_config:FLPlan;
feature_map:[FeatureMap];
timestamp:string;
}
table FLPlan {
fl_name:string;
iterations:int;
epochs:int;
early_stop:EarlyStop;
mini_batch:int;
shuffle:bool = false;
lr:float;
aggregation:Aggregation;
metrics:[Metrics];
cipher:CipherPublicParams;
}
table RequestUpdateModel{
fl_name:string;
fl_id:string;
iteration:int;
feature_map:[FeatureMap];
timestamp:string;
}
table ResponseUpdateModel{
retcode:int;
reason:string;
feature_map:[FeatureMap];
next_req_time:string;
timestamp:string;
}
table RequestAsyncUpdateModel{
fl_name:string;
fl_id:string;
iteration:int;
data_size:int;
feature_map:[FeatureMap];
}
table ResponseAsyncUpdateModel{
retcode:int;
reason:string;
iteration:int;
}
table RequestOverwriteWeightsByKey{
iteration:int;
feature_map:[FeatureMap];
}
table ResponseOverwriteWeightsByKey{
retcode:int;
reason:string;
}
table RequestGetModel{
fl_name:string;
iteration:int;
timestamp:string;
}
table ResponseGetModel{
retcode:int;
reason:string;
iteration:int;
feature_map:[FeatureMap];
timestamp:string;
}
table RequestAsyncGetModel{
fl_name:string;
iteration:int;
}
table ResponseAsyncGetModel{
retcode:int;
reason:string;
iteration:int;
feature_map:[FeatureMap];
}
table RequestGetWeightsByKey{
iteration:int;
weight_names:[string];
}
table ResponseGetWeightsByKey{
retcode:int;
reason:string;
feature_map:[FeatureMap];
}
// FeatureMapList refers to the whole trained model.
table FeatureMapList {
feature_map:[FeatureMap];
}