forked from mindspore-Ecosystem/mindspore
Add server code part2
This commit is contained in:
parent
cb6e055736
commit
12f95b51f4
|
@ -35,7 +35,7 @@ function(ms_build_flatbuffers source_schema_files
|
||||||
set(total_schema_dirs -I ${schema_dir} ${total_schema_dirs})
|
set(total_schema_dirs -I ${schema_dir} ${total_schema_dirs})
|
||||||
endforeach()
|
endforeach()
|
||||||
|
|
||||||
foreach(schema ${source_schema_files})
|
foreach(schema IN LISTS ${source_schema_files})
|
||||||
get_filename_component(filename ${schema} NAME_WE)
|
get_filename_component(filename ${schema} NAME_WE)
|
||||||
if(NOT ${generated_output_dir} STREQUAL "")
|
if(NOT ${generated_output_dir} STREQUAL "")
|
||||||
set(generated_file ${generated_output_dir}/${filename}_generated.h)
|
set(generated_file ${generated_output_dir}/${filename}_generated.h)
|
||||||
|
|
|
@ -212,7 +212,7 @@ if(ENABLE_GPU)
|
||||||
)
|
)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if(ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))
|
if(ENABLE_CPU AND NOT WIN32)
|
||||||
install(
|
install(
|
||||||
TARGETS ps_cache
|
TARGETS ps_cache
|
||||||
DESTINATION ${INSTALL_LIB_DIR}
|
DESTINATION ${INSTALL_LIB_DIR}
|
||||||
|
|
|
@ -373,7 +373,7 @@ elseif(CMAKE_SYSTEM_NAME MATCHES "Darwin")
|
||||||
target_link_libraries(mindspore mindspore_gvar)
|
target_link_libraries(mindspore mindspore_gvar)
|
||||||
target_link_libraries(_c_expression PRIVATE -Wl,-force_load mindspore mindspore_core -Wl,-noall_load)
|
target_link_libraries(_c_expression PRIVATE -Wl,-force_load mindspore mindspore_core -Wl,-noall_load)
|
||||||
else()
|
else()
|
||||||
if(ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))
|
if(ENABLE_CPU AND NOT WIN32)
|
||||||
target_link_libraries(mindspore proto_input mindspore::protobuf
|
target_link_libraries(mindspore proto_input mindspore::protobuf
|
||||||
mindspore::event mindspore::event_pthreads mindspore::event_openssl mindspore::json)
|
mindspore::event mindspore::event_pthreads mindspore::event_openssl mindspore::json)
|
||||||
target_link_libraries(mindspore -Wl,--no-as-needed mindspore::event_core ps_cache)
|
target_link_libraries(mindspore -Wl,--no-as-needed mindspore::event_core ps_cache)
|
||||||
|
|
|
@ -75,7 +75,7 @@ if(ENABLE_CPU)
|
||||||
endif()
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if(NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)))
|
if(NOT ENABLE_CPU OR WIN32)
|
||||||
list(REMOVE_ITEM CPU_SRC_LIST "cpu/ps/apply_momentum_ps_kernel.cc")
|
list(REMOVE_ITEM CPU_SRC_LIST "cpu/ps/apply_momentum_ps_kernel.cc")
|
||||||
list(REMOVE_ITEM CPU_SRC_LIST "cpu/ps/embedding_look_up_proxy_kernel.cc")
|
list(REMOVE_ITEM CPU_SRC_LIST "cpu/ps/embedding_look_up_proxy_kernel.cc")
|
||||||
list(REMOVE_ITEM CPU_SRC_LIST "cpu/ps/embedding_look_up_ps_kernel.cc")
|
list(REMOVE_ITEM CPU_SRC_LIST "cpu/ps/embedding_look_up_ps_kernel.cc")
|
||||||
|
|
|
@ -421,7 +421,7 @@ void AscendSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_gra
|
||||||
size = abstract::ShapeSize(shape_tmp) * abstract::TypeIdSize(tensor->data_type());
|
size = abstract::ShapeSize(shape_tmp) * abstract::TypeIdSize(tensor->data_type());
|
||||||
}
|
}
|
||||||
if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0) && TensorNeedSync(input_node, tensor)) {
|
if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0) && TensorNeedSync(input_node, tensor)) {
|
||||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
#if (ENABLE_CPU && !_WIN32)
|
||||||
const std::string &param_name = input_node->fullname_with_scope();
|
const std::string &param_name = input_node->fullname_with_scope();
|
||||||
if (ps::ps_cache_instance.IsHashTable(param_name)) {
|
if (ps::ps_cache_instance.IsHashTable(param_name)) {
|
||||||
continue;
|
continue;
|
||||||
|
|
|
@ -33,7 +33,7 @@
|
||||||
#include "debug/anf_ir_dump.h"
|
#include "debug/anf_ir_dump.h"
|
||||||
#include "debug/dump_proto.h"
|
#include "debug/dump_proto.h"
|
||||||
#include "debug/data_dump/dump_json_parser.h"
|
#include "debug/data_dump/dump_json_parser.h"
|
||||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
#if (ENABLE_CPU && !_WIN32)
|
||||||
#include "ps/util.h"
|
#include "ps/util.h"
|
||||||
#include "ps/ps_context.h"
|
#include "ps/ps_context.h"
|
||||||
#endif
|
#endif
|
||||||
|
@ -74,7 +74,7 @@ void CPUSession::Reorder(std::vector<CNodePtr> *node_list) { AnfAlgo::ReorderPos
|
||||||
void CPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
|
void CPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
|
||||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||||
auto pm = std::make_shared<opt::PassManager>();
|
auto pm = std::make_shared<opt::PassManager>();
|
||||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
#if (ENABLE_CPU && !_WIN32)
|
||||||
auto ms_context = MsContext::GetInstance();
|
auto ms_context = MsContext::GetInstance();
|
||||||
MS_EXCEPTION_IF_NULL(ms_context);
|
MS_EXCEPTION_IF_NULL(ms_context);
|
||||||
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode && ps::PSContext::instance()->is_ps_mode()) {
|
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode && ps::PSContext::instance()->is_ps_mode()) {
|
||||||
|
@ -174,7 +174,7 @@ void CPUSession::PreExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_grap
|
||||||
MS_LOG(INFO) << "Bind input output address";
|
MS_LOG(INFO) << "Bind input output address";
|
||||||
runtime_.BindInputOutput(kernel_graph.get(), inputs, outputs);
|
runtime_.BindInputOutput(kernel_graph.get(), inputs, outputs);
|
||||||
|
|
||||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
#if (ENABLE_CPU && !_WIN32)
|
||||||
InitPSParamAndOptim(kernel_graph, inputs);
|
InitPSParamAndOptim(kernel_graph, inputs);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,7 +21,7 @@
|
||||||
#include "utils/comm_manager.h"
|
#include "utils/comm_manager.h"
|
||||||
#include "utils/scoped_long_running.h"
|
#include "utils/scoped_long_running.h"
|
||||||
#include "pybind_api/ir/tensor_py.h"
|
#include "pybind_api/ir/tensor_py.h"
|
||||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
#if (ENABLE_CPU && !_WIN32)
|
||||||
#include "ps/ps_cache/ps_cache_manager.h"
|
#include "ps/ps_cache/ps_cache_manager.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
|
|
@ -43,7 +43,7 @@
|
||||||
#include "debug/common.h"
|
#include "debug/common.h"
|
||||||
#include "utils/trace_base.h"
|
#include "utils/trace_base.h"
|
||||||
#include "frontend/parallel/context.h"
|
#include "frontend/parallel/context.h"
|
||||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
#if (ENABLE_CPU && !_WIN32)
|
||||||
#include "ps/ps_cache/ps_cache_manager.h"
|
#include "ps/ps_cache/ps_cache_manager.h"
|
||||||
#include "ps/constants.h"
|
#include "ps/constants.h"
|
||||||
#include "ps/util.h"
|
#include "ps/util.h"
|
||||||
|
@ -2357,7 +2357,7 @@ void SessionBasic::DumpGraph(const std::shared_ptr<KernelGraph> &kernel_graph) {
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
#if (ENABLE_CPU && !_WIN32)
|
||||||
void SessionBasic::InitPsWorker(const KernelGraphPtr &kernel_graph) {
|
void SessionBasic::InitPsWorker(const KernelGraphPtr &kernel_graph) {
|
||||||
if (!ps::PSContext::instance()->is_worker()) {
|
if (!ps::PSContext::instance()->is_worker()) {
|
||||||
return;
|
return;
|
||||||
|
|
|
@ -244,7 +244,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
|
||||||
std::vector<uint32_t> GetAllReduceSplitIndex();
|
std::vector<uint32_t> GetAllReduceSplitIndex();
|
||||||
virtual std::string GetCommWorldGroup() { return std::string(); }
|
virtual std::string GetCommWorldGroup() { return std::string(); }
|
||||||
void DumpGraph(const std::shared_ptr<KernelGraph> &kernel_graph);
|
void DumpGraph(const std::shared_ptr<KernelGraph> &kernel_graph);
|
||||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
#if (ENABLE_CPU && !_WIN32)
|
||||||
void CheckPSModeConsistence(const KernelGraphPtr &kernel_graph) const;
|
void CheckPSModeConsistence(const KernelGraphPtr &kernel_graph) const;
|
||||||
void GetBatchElements(const AnfNodePtr &kernel_node) const;
|
void GetBatchElements(const AnfNodePtr &kernel_node) const;
|
||||||
void InitPsWorker(const KernelGraphPtr &kernel_graph);
|
void InitPsWorker(const KernelGraphPtr &kernel_graph);
|
||||||
|
@ -263,7 +263,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
|
||||||
#if !defined(_WIN32) && !defined(_WIN64)
|
#if !defined(_WIN32) && !defined(_WIN64)
|
||||||
std::shared_ptr<Debugger> debugger_;
|
std::shared_ptr<Debugger> debugger_;
|
||||||
#endif
|
#endif
|
||||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
#if (ENABLE_CPU && !_WIN32)
|
||||||
bool initialized_ps_cache_{false};
|
bool initialized_ps_cache_{false};
|
||||||
#endif
|
#endif
|
||||||
};
|
};
|
||||||
|
|
|
@ -25,7 +25,7 @@
|
||||||
#include "frontend/parallel/device_matrix.h"
|
#include "frontend/parallel/device_matrix.h"
|
||||||
#include "frontend/parallel/graph_util/generate_graph.h"
|
#include "frontend/parallel/graph_util/generate_graph.h"
|
||||||
#include "frontend/parallel/context.h"
|
#include "frontend/parallel/context.h"
|
||||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
#if (ENABLE_CPU && !_WIN32)
|
||||||
#include "ps/ps_cache/ps_cache_manager.h"
|
#include "ps/ps_cache/ps_cache_manager.h"
|
||||||
#include "utils/ms_context.h"
|
#include "utils/ms_context.h"
|
||||||
#endif
|
#endif
|
||||||
|
@ -160,7 +160,7 @@ Status GatherPInfo::GetAttrs() {
|
||||||
if (std::find(inputs_shape_[1].begin(), inputs_shape_[1].end(), -1) != inputs_shape_[1].end()) {
|
if (std::find(inputs_shape_[1].begin(), inputs_shape_[1].end(), -1) != inputs_shape_[1].end()) {
|
||||||
dynamic_shape_indices_ = true;
|
dynamic_shape_indices_ = true;
|
||||||
}
|
}
|
||||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
#if (ENABLE_CPU && !_WIN32)
|
||||||
MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
|
MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
|
||||||
std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode();
|
std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode();
|
||||||
MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
|
MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
|
||||||
|
@ -617,7 +617,7 @@ Status GatherPInfo::InferBias() {
|
||||||
rank = rank % (params_strategy[0] * params_strategy[1]);
|
rank = rank % (params_strategy[0] * params_strategy[1]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
#if (ENABLE_CPU && !_WIN32)
|
||||||
if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
|
if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
|
||||||
bias_ = static_cast<int64_t>(ps::PsCacheManager::GetInstance().cache_indices_lower_bound());
|
bias_ = static_cast<int64_t>(ps::PsCacheManager::GetInstance().cache_indices_lower_bound());
|
||||||
return SUCCESS;
|
return SUCCESS;
|
||||||
|
|
|
@ -28,7 +28,7 @@
|
||||||
#include "frontend/parallel/strategy.h"
|
#include "frontend/parallel/strategy.h"
|
||||||
#include "frontend/parallel/context.h"
|
#include "frontend/parallel/context.h"
|
||||||
#include "frontend/parallel/tensor_layout/tensor_redistribution.h"
|
#include "frontend/parallel/tensor_layout/tensor_redistribution.h"
|
||||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
#if (ENABLE_CPU && !_WIN32)
|
||||||
#include "ps/ps_cache/ps_cache_manager.h"
|
#include "ps/ps_cache/ps_cache_manager.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
@ -192,7 +192,7 @@ Status UniqueInfo::GenerateStrategies(int64_t stage_id) {
|
||||||
return SUCCESS;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
#if (ENABLE_CPU && !_WIN32)
|
||||||
Status UniqueInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
|
Status UniqueInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
|
||||||
GenerateGraph gen_g = GenerateGraph();
|
GenerateGraph gen_g = GenerateGraph();
|
||||||
if (gen_g.Init(cnode) != SUCCESS) {
|
if (gen_g.Init(cnode) != SUCCESS) {
|
||||||
|
@ -230,7 +230,7 @@ Status UniqueInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
ReplaceGraphPtr UniqueInfo::replace_graph(const CNodePtr &cnode) {
|
ReplaceGraphPtr UniqueInfo::replace_graph(const CNodePtr &cnode) {
|
||||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
#if (ENABLE_CPU && !_WIN32)
|
||||||
if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
|
if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
|
||||||
auto inputs = cnode->inputs();
|
auto inputs = cnode->inputs();
|
||||||
if (inputs.empty()) {
|
if (inputs.empty()) {
|
||||||
|
|
|
@ -51,7 +51,7 @@ class UniqueInfo : public OperatorInfo {
|
||||||
Status InferMirrorOps() override;
|
Status InferMirrorOps() override;
|
||||||
Status InferForwardCommunication() override { return SUCCESS; }
|
Status InferForwardCommunication() override { return SUCCESS; }
|
||||||
Status InferAsLossDivisor() override { return SUCCESS; }
|
Status InferAsLossDivisor() override { return SUCCESS; }
|
||||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
#if (ENABLE_CPU && !_WIN32)
|
||||||
Status ComputeReplaceGraph(const CNodePtr &cnode);
|
Status ComputeReplaceGraph(const CNodePtr &cnode);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
|
|
@ -47,14 +47,14 @@
|
||||||
#include "ir/anf.h"
|
#include "ir/anf.h"
|
||||||
#include "ir/param_info.h"
|
#include "ir/param_info.h"
|
||||||
#include "ir/tensor.h"
|
#include "ir/tensor.h"
|
||||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
#if (ENABLE_CPU && !_WIN32)
|
||||||
#include "ps/util.h"
|
#include "ps/util.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace parallel {
|
namespace parallel {
|
||||||
bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) {
|
bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) {
|
||||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
#if (ENABLE_CPU && !_WIN32)
|
||||||
if (ps::Util::IsRoleOfPServer() || ps::Util::IsRoleOfScheduler()) {
|
if (ps::Util::IsRoleOfPServer() || ps::Util::IsRoleOfScheduler()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
|
@ -46,7 +46,7 @@
|
||||||
#include "utils/ms_context.h"
|
#include "utils/ms_context.h"
|
||||||
#include "utils/symbolic.h"
|
#include "utils/symbolic.h"
|
||||||
#include "mindspore/core/utils/parallel_node_check.h"
|
#include "mindspore/core/utils/parallel_node_check.h"
|
||||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
#if (ENABLE_CPU && !_WIN32)
|
||||||
#include "ps/util.h"
|
#include "ps/util.h"
|
||||||
#include "ps/ps_context.h"
|
#include "ps/ps_context.h"
|
||||||
#endif
|
#endif
|
||||||
|
@ -3553,7 +3553,7 @@ static void HandleFullySplitParameters(const FuncGraphPtr &root) {
|
||||||
}
|
}
|
||||||
|
|
||||||
bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) {
|
bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) {
|
||||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
#if (ENABLE_CPU && !_WIN32)
|
||||||
if (ps::PSContext::instance()->is_server() || ps::PSContext::instance()->is_scheduler()) {
|
if (ps::PSContext::instance()->is_server() || ps::PSContext::instance()->is_scheduler()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
|
@ -295,7 +295,7 @@ if(${CMAKE_SYSTEM_NAME} MATCHES "Windows")
|
||||||
target_link_libraries(_c_dataengine PRIVATE _c_mindrecord ${MINDRECORD_LINK_OBJECT} mindspore::sqlite)
|
target_link_libraries(_c_dataengine PRIVATE _c_mindrecord ${MINDRECORD_LINK_OBJECT} mindspore::sqlite)
|
||||||
else()
|
else()
|
||||||
target_link_libraries(_c_dataengine PRIVATE _c_mindrecord)
|
target_link_libraries(_c_dataengine PRIVATE _c_mindrecord)
|
||||||
if(ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))
|
if(ENABLE_CPU AND NOT WIN32)
|
||||||
if(${ENABLE_IBVERBS} STREQUAL "ON")
|
if(${ENABLE_IBVERBS} STREQUAL "ON")
|
||||||
target_link_libraries(_c_dataengine PRIVATE ibverbs rdmacm)
|
target_link_libraries(_c_dataengine PRIVATE ibverbs rdmacm)
|
||||||
endif()
|
endif()
|
||||||
|
|
|
@ -1,7 +1,8 @@
|
||||||
add_subdirectory(perf EXCLUDE_FROM_ALL)
|
add_subdirectory(perf EXCLUDE_FROM_ALL)
|
||||||
include_directories("${CMAKE_BINARY_DIR}/minddata/dataset/engine/cache")
|
include_directories("${CMAKE_BINARY_DIR}/minddata/dataset/engine/cache")
|
||||||
set(MD_FLATBUFFER_OU "${CMAKE_BINARY_DIR}/minddata/dataset/engine/cache")
|
set(MD_FLATBUFFER_OU "${CMAKE_BINARY_DIR}/minddata/dataset/engine/cache")
|
||||||
ms_build_flatbuffers("de_tensor.fbs" ${CMAKE_CURRENT_SOURCE_DIR} generated_engine_files ${MD_FLATBUFFER_OU})
|
set(FBS_FILES de_tensor.fbs)
|
||||||
|
ms_build_flatbuffers(FBS_FILES ${CMAKE_CURRENT_SOURCE_DIR} generated_engine_files ${MD_FLATBUFFER_OU})
|
||||||
|
|
||||||
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
|
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
|
||||||
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
|
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
|
||||||
|
|
|
@ -43,7 +43,7 @@
|
||||||
#include "vm/transform.h"
|
#include "vm/transform.h"
|
||||||
#include "parse/python_adapter.h"
|
#include "parse/python_adapter.h"
|
||||||
#include "frontend/optimizer/py_pass_manager.h"
|
#include "frontend/optimizer/py_pass_manager.h"
|
||||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
#if (ENABLE_CPU && !_WIN32)
|
||||||
#include "ps/parameter_server.h"
|
#include "ps/parameter_server.h"
|
||||||
#include "ps/scheduler.h"
|
#include "ps/scheduler.h"
|
||||||
#include "ps/worker.h"
|
#include "ps/worker.h"
|
||||||
|
@ -606,7 +606,7 @@ bool ExecuteAction(const ResourcePtr &res) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
#if (ENABLE_CPU && !_WIN32)
|
||||||
bool StartPSWorkerAction(const ResourcePtr &res) {
|
bool StartPSWorkerAction(const ResourcePtr &res) {
|
||||||
ps::Worker::GetInstance().Run();
|
ps::Worker::GetInstance().Run();
|
||||||
return true;
|
return true;
|
||||||
|
@ -782,7 +782,7 @@ std::vector<ActionItem> VmPipeline() {
|
||||||
actions.emplace_back(std::make_pair("auto_monad_reorder", OrderEnforceAction));
|
actions.emplace_back(std::make_pair("auto_monad_reorder", OrderEnforceAction));
|
||||||
|
|
||||||
actions.emplace_back(std::make_pair("validate", ValidateAction));
|
actions.emplace_back(std::make_pair("validate", ValidateAction));
|
||||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
#if (ENABLE_CPU && !_WIN32)
|
||||||
if (ps::PSContext::instance()->is_worker()) {
|
if (ps::PSContext::instance()->is_worker()) {
|
||||||
actions.emplace_back(std::make_pair("worker", StartPSWorkerAction));
|
actions.emplace_back(std::make_pair("worker", StartPSWorkerAction));
|
||||||
}
|
}
|
||||||
|
@ -796,7 +796,7 @@ std::vector<ActionItem> VmPipeline() {
|
||||||
return actions;
|
return actions;
|
||||||
}
|
}
|
||||||
|
|
||||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
#if (ENABLE_CPU && !_WIN32)
|
||||||
std::vector<ActionItem> PServerPipeline() {
|
std::vector<ActionItem> PServerPipeline() {
|
||||||
auto actions = CommonPipeline();
|
auto actions = CommonPipeline();
|
||||||
actions.emplace_back(std::make_pair("optimize", VmOptimizeAction));
|
actions.emplace_back(std::make_pair("optimize", VmOptimizeAction));
|
||||||
|
|
|
@ -34,7 +34,7 @@
|
||||||
#else
|
#else
|
||||||
#include "runtime/device/gpu/distribution/collective_fake_init.h"
|
#include "runtime/device/gpu/distribution/collective_fake_init.h"
|
||||||
#endif
|
#endif
|
||||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
#if (ENABLE_CPU && !_WIN32)
|
||||||
#include "ps/util.h"
|
#include "ps/util.h"
|
||||||
#endif
|
#endif
|
||||||
#include "ps/ps_context.h"
|
#include "ps/ps_context.h"
|
||||||
|
|
|
@ -42,7 +42,7 @@
|
||||||
#include "pipeline/jit/pipeline_split.h"
|
#include "pipeline/jit/pipeline_split.h"
|
||||||
#include "pipeline/jit/static_analysis/auto_monad.h"
|
#include "pipeline/jit/static_analysis/auto_monad.h"
|
||||||
#include "frontend/optimizer/irpass/gradient_eliminate.h"
|
#include "frontend/optimizer/irpass/gradient_eliminate.h"
|
||||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
#if (ENABLE_CPU && !_WIN32)
|
||||||
#include "ps/util.h"
|
#include "ps/util.h"
|
||||||
#include "ps/ps_context.h"
|
#include "ps/ps_context.h"
|
||||||
#endif
|
#endif
|
||||||
|
@ -407,7 +407,7 @@ bool AddRecomputationPass(const ResourcePtr &res) {
|
||||||
}
|
}
|
||||||
|
|
||||||
bool AddCacheEmbeddingPass(const ResourcePtr &res) {
|
bool AddCacheEmbeddingPass(const ResourcePtr &res) {
|
||||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
#if (ENABLE_CPU && !_WIN32)
|
||||||
if (ps::PSContext::instance()->is_ps_mode()) {
|
if (ps::PSContext::instance()->is_ps_mode()) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
|
@ -49,7 +49,7 @@
|
||||||
#include "utils/shape_utils.h"
|
#include "utils/shape_utils.h"
|
||||||
#include "utils/info.h"
|
#include "utils/info.h"
|
||||||
#include "load_mindir/load_model.h"
|
#include "load_mindir/load_model.h"
|
||||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
#if (ENABLE_CPU && !_WIN32)
|
||||||
#include "ps/constants.h"
|
#include "ps/constants.h"
|
||||||
#include "ps/util.h"
|
#include "ps/util.h"
|
||||||
#include "ps/worker.h"
|
#include "ps/worker.h"
|
||||||
|
@ -528,7 +528,7 @@ std::vector<ActionItem> GetPipeline(const ResourcePtr &resource, const std::stri
|
||||||
|
|
||||||
std::string backend = MsContext::GetInstance()->backend_policy();
|
std::string backend = MsContext::GetInstance()->backend_policy();
|
||||||
|
|
||||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
#if (ENABLE_CPU && !_WIN32)
|
||||||
if (ps::PSContext::instance()->is_server()) {
|
if (ps::PSContext::instance()->is_server()) {
|
||||||
resource->results()[kBackend] = compile::CreateBackend();
|
resource->results()[kBackend] = compile::CreateBackend();
|
||||||
return PServerPipeline();
|
return PServerPipeline();
|
||||||
|
@ -961,7 +961,7 @@ bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t ba
|
||||||
bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batch_size,
|
bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batch_size,
|
||||||
const std::vector<TypePtr> &types, const std::vector<std::vector<int64_t>> &shapes,
|
const std::vector<TypePtr> &types, const std::vector<std::vector<int64_t>> &shapes,
|
||||||
const std::vector<int64_t> &input_indexes, bool need_run) {
|
const std::vector<int64_t> &input_indexes, bool need_run) {
|
||||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
#if (ENABLE_CPU && !_WIN32)
|
||||||
if ((ps::PSContext::instance()->is_ps_mode()) && (!ps::PSContext::instance()->is_worker())) {
|
if ((ps::PSContext::instance()->is_ps_mode()) && (!ps::PSContext::instance()->is_worker())) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -1027,7 +1027,7 @@ bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batc
|
||||||
}
|
}
|
||||||
ConfigManager::GetInstance().set_iter_num(size);
|
ConfigManager::GetInstance().set_iter_num(size);
|
||||||
// PS cache does not support loop sink.
|
// PS cache does not support loop sink.
|
||||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
#if (ENABLE_CPU && !_WIN32)
|
||||||
if (ps::PSContext::instance()->is_worker() && ps::PsDataPrefetch::GetInstance().cache_enable()) {
|
if (ps::PSContext::instance()->is_worker() && ps::PsDataPrefetch::GetInstance().cache_enable()) {
|
||||||
ps::PsDataPrefetch::GetInstance().CreateDataChannel(queue_name, LongToSize(size));
|
ps::PsDataPrefetch::GetInstance().CreateDataChannel(queue_name, LongToSize(size));
|
||||||
ConfigManager::GetInstance().set_iter_num(1);
|
ConfigManager::GetInstance().set_iter_num(1);
|
||||||
|
@ -1150,7 +1150,7 @@ void FinalizeBackend() {
|
||||||
void ClearResAtexit() {
|
void ClearResAtexit() {
|
||||||
MS_LOG(DEBUG) << "Pipeline clear all resource";
|
MS_LOG(DEBUG) << "Pipeline clear all resource";
|
||||||
pynative::ClearPyNativeSession();
|
pynative::ClearPyNativeSession();
|
||||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
#if (ENABLE_CPU && !_WIN32)
|
||||||
if (ps::PSContext::instance()->is_ps_mode() && ps::PSContext::instance()->is_worker()) {
|
if (ps::PSContext::instance()->is_ps_mode() && ps::PSContext::instance()->is_worker()) {
|
||||||
if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
|
if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
|
||||||
ps::ps_cache_instance.Finalize();
|
ps::ps_cache_instance.Finalize();
|
||||||
|
|
|
@ -1,6 +1,13 @@
|
||||||
file(GLOB_RECURSE _PS_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
|
file(GLOB_RECURSE _PS_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
|
||||||
|
|
||||||
if(NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)))
|
set(SERVER_FLATBUFFER_OUTPUT "${CMAKE_BINARY_DIR}/schema")
|
||||||
|
set(FBS_FILES
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/../../schema/cipher.fbs
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/../../schema/fl_job.fbs
|
||||||
|
)
|
||||||
|
ms_build_flatbuffers(FBS_FILES ${CMAKE_CURRENT_SOURCE_DIR}../../schema generated_fbs_files ${SERVER_FLATBUFFER_OUTPUT})
|
||||||
|
|
||||||
|
if(NOT ENABLE_CPU OR WIN32)
|
||||||
list(REMOVE_ITEM _PS_SRC_FILES "optimizer_info_builder.cc")
|
list(REMOVE_ITEM _PS_SRC_FILES "optimizer_info_builder.cc")
|
||||||
list(REMOVE_ITEM _PS_SRC_FILES "optimizer_info.cc")
|
list(REMOVE_ITEM _PS_SRC_FILES "optimizer_info.cc")
|
||||||
list(REMOVE_ITEM _PS_SRC_FILES "scheduler.cc")
|
list(REMOVE_ITEM _PS_SRC_FILES "scheduler.cc")
|
||||||
|
@ -12,11 +19,6 @@ if(NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)))
|
||||||
list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/tcp_client.cc")
|
list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/tcp_client.cc")
|
||||||
list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/tcp_message_handler.cc")
|
list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/tcp_message_handler.cc")
|
||||||
list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/tcp_server.cc")
|
list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/tcp_server.cc")
|
||||||
list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/communicator_base.cc")
|
|
||||||
list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/http_communicator.cc")
|
|
||||||
list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/tcp_communicator.cc")
|
|
||||||
list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/http_msg_handler.cc")
|
|
||||||
list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/tcp_msg_handler.cc")
|
|
||||||
list(REMOVE_ITEM _PS_SRC_FILES "core/node.cc")
|
list(REMOVE_ITEM _PS_SRC_FILES "core/node.cc")
|
||||||
list(REMOVE_ITEM _PS_SRC_FILES "core/node_manager.cc")
|
list(REMOVE_ITEM _PS_SRC_FILES "core/node_manager.cc")
|
||||||
list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_cache_manager.cc")
|
list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_cache_manager.cc")
|
||||||
|
@ -39,18 +41,32 @@ if(NOT ENABLE_GPU)
|
||||||
list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/gpu/gpu_ps_cache.cc")
|
list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/gpu/gpu_ps_cache.cc")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if(WIN32 OR NOT ENABLE_CPU)
|
if(NOT ENABLE_CPU OR WIN32)
|
||||||
|
list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/communicator_base.cc")
|
||||||
|
list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/http_communicator.cc")
|
||||||
|
list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/tcp_communicator.cc")
|
||||||
|
list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/http_msg_handler.cc")
|
||||||
|
list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/tcp_msg_handler.cc")
|
||||||
list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/apply_momentum_kernel.cc")
|
list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/apply_momentum_kernel.cc")
|
||||||
list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/aggregation_kernel_factory.cc")
|
list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/aggregation_kernel_factory.cc")
|
||||||
list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/dense_grad_accum_kernel.cc")
|
list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/dense_grad_accum_kernel.cc")
|
||||||
list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/optimizer_kernel_factory.cc")
|
list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/optimizer_kernel_factory.cc")
|
||||||
|
list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/round/round_kernel_factory.cc")
|
||||||
|
list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/round/round_kernel.cc")
|
||||||
|
list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/round/start_fl_job_kernel.cc")
|
||||||
list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/params_info.cc")
|
list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/params_info.cc")
|
||||||
list(REMOVE_ITEM _PS_SRC_FILES "server/consistent_hash_ring.cc")
|
list(REMOVE_ITEM _PS_SRC_FILES "server/consistent_hash_ring.cc")
|
||||||
list(REMOVE_ITEM _PS_SRC_FILES "server/iteration_timer.cc")
|
list(REMOVE_ITEM _PS_SRC_FILES "server/iteration_timer.cc")
|
||||||
list(REMOVE_ITEM _PS_SRC_FILES "server/local_meta_storage.cc")
|
list(REMOVE_ITEM _PS_SRC_FILES "server/local_meta_store.cc")
|
||||||
list(REMOVE_ITEM _PS_SRC_FILES "server/memory_register.cc")
|
list(REMOVE_ITEM _PS_SRC_FILES "server/memory_register.cc")
|
||||||
list(REMOVE_ITEM _PS_SRC_FILES "server/parameter_aggregator.cc")
|
list(REMOVE_ITEM _PS_SRC_FILES "server/parameter_aggregator.cc")
|
||||||
list(REMOVE_ITEM _PS_SRC_FILES "server/executor.cc")
|
list(REMOVE_ITEM _PS_SRC_FILES "server/executor.cc")
|
||||||
|
list(REMOVE_ITEM _PS_SRC_FILES "server/collective_ops_impl.cc")
|
||||||
|
list(REMOVE_ITEM _PS_SRC_FILES "server/distributed_count_service.cc")
|
||||||
|
list(REMOVE_ITEM _PS_SRC_FILES "server/distributed_metadata_store.cc")
|
||||||
|
list(REMOVE_ITEM _PS_SRC_FILES "server/iteration.cc")
|
||||||
|
list(REMOVE_ITEM _PS_SRC_FILES "server/model_store.cc")
|
||||||
|
list(REMOVE_ITEM _PS_SRC_FILES "server/round.cc")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_data/ps_data_prefetch.cc")
|
list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_data/ps_data_prefetch.cc")
|
||||||
|
@ -59,3 +75,5 @@ add_subdirectory(ps_cache)
|
||||||
|
|
||||||
set_property(SOURCE ${_PS_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PS)
|
set_property(SOURCE ${_PS_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PS)
|
||||||
add_library(_mindspore_ps_obj OBJECT ${_PS_SRC_FILES})
|
add_library(_mindspore_ps_obj OBJECT ${_PS_SRC_FILES})
|
||||||
|
add_dependencies(_mindspore_ps_obj generated_fbs_files)
|
||||||
|
target_link_libraries(_mindspore_ps_obj mindspore::flatbuffers)
|
||||||
|
|
|
@ -34,13 +34,26 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace ps {
|
||||||
namespace core {
|
namespace core {
|
||||||
enum class TcpUserCommand { kPush, kPull, kCount, kReachThreshold, kResetCount, kGetValue, kPutValue, kCounterEvent };
|
enum class TcpUserCommand {
|
||||||
|
kPush,
|
||||||
|
kPull,
|
||||||
|
kCount,
|
||||||
|
kReachThreshold,
|
||||||
|
kResetCount,
|
||||||
|
kGetMetadata,
|
||||||
|
kUpdateMetadata,
|
||||||
|
kCounterEvent
|
||||||
|
};
|
||||||
|
|
||||||
const std::unordered_map<TcpUserCommand, std::string> kUserCommandToMsgType = {
|
const std::unordered_map<TcpUserCommand, std::string> kUserCommandToMsgType = {
|
||||||
{TcpUserCommand::kPush, "push"}, {TcpUserCommand::kPull, "pull"},
|
{TcpUserCommand::kPush, "push"},
|
||||||
{TcpUserCommand::kCount, "count"}, {TcpUserCommand::kReachThreshold, "reachThreshold"},
|
{TcpUserCommand::kPull, "pull"},
|
||||||
{TcpUserCommand::kResetCount, "resetCnt"}, {TcpUserCommand::kGetValue, "getValue"},
|
{TcpUserCommand::kCount, "count"},
|
||||||
{TcpUserCommand::kPutValue, "putValue"}, {TcpUserCommand::kCounterEvent, "counterEvent"},
|
{TcpUserCommand::kReachThreshold, "countReachThreshold"},
|
||||||
|
{TcpUserCommand::kResetCount, "resetCnt"},
|
||||||
|
{TcpUserCommand::kGetMetadata, "getMetadata"},
|
||||||
|
{TcpUserCommand::kUpdateMetadata, "updateMetadata"},
|
||||||
|
{TcpUserCommand::kCounterEvent, "counterEvent"},
|
||||||
};
|
};
|
||||||
|
|
||||||
class TcpCommunicator : public CommunicatorBase {
|
class TcpCommunicator : public CommunicatorBase {
|
||||||
|
|
|
@ -0,0 +1,155 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
syntax = "proto3";
|
||||||
|
package mindspore.ps;
|
||||||
|
|
||||||
|
// Opaque byte payload.
message CollectiveData {
  bytes data = 1;
}

// Counting request: carries a counter name and an id.
message CountRequest {
  string name = 1;
  string id = 2;
}

// Reply to a counting request: a success flag plus a textual reason.
message CountResponse {
  bool result = 1;
  string reason = 2;
}

// Threshold query for the counter called `name`.
message CountReachThresholdRequest {
  string name = 1;
}

// Reply to a threshold query.
message CountReachThresholdResponse {
  bool is_enough = 1;
}

// Reset request for the counter called `name`.
message ResetCounterRequest {
  string name = 1;
}
|
||||||
|
|
||||||
|
// Metadata update: a name together with the new value as raw bytes.
message UpdateMetadataRequest {
  string name = 1;
  bytes value = 2;
}

// Metadata lookup by name.
message GetMetadataRequest {
  string name = 1;
}

// Reply to a metadata lookup: the value as raw bytes.
message GetMetadataResponse {
  bytes value = 1;
}
|
||||||
|
|
||||||
|
// Kinds of counter event.
enum CounterEventType {
  FIRST_CNT = 0;
  LAST_CNT = 1;
}

// A counter event: its kind, the counter name and an opaque payload.
message CounterEvent {
  CounterEventType type = 1;
  string name = 2;
  bytes data = 3;
}
|
||||||
|
|
||||||
|
// A single fl_id.
message FLId {
  string fl_id = 1;
}

// A list of fl_ids.
message UpdateModelClientList {
  repeated string fl_id = 1;
}

// Per-device metadata: fl_name, fl_id and data size.
message DeviceMeta {
  string fl_name = 1;
  string fl_id = 2;
  uint64 data_size = 3;
}

// Map from a string key (fl_id, going by the field name) to DeviceMeta.
message FLIdToDeviceMeta {
  map<string, DeviceMeta> fl_id_to_meta = 1;
}

// A single threshold value.
message UpdateModelThreshold {
  uint64 threshold = 1;
}
|
||||||
|
|
||||||
|
// Map from a string key to that entry's SharesPb.
message ClientShares {
  map<string, SharesPb> client_secret_shares = 1;
}

// One (fl_id, SharesPb) pair.
message PairClientShares {
  string fl_id = 1;
  SharesPb client_shares = 2;
}

// Map from a string key to that entry's KeysPb.
message ClientKeys {
  map<string, KeysPb> client_keys = 1;
}

// Wrapper around a single OneClientNoises value.
message ClientNoises {
  OneClientNoises one_client_noises = 1;
}

// One (fl_id, KeysPb) pair.
message PairClientKeys {
  string fl_id = 1;
  KeysPb client_keys = 2;
}

// A list of float noise values.
message OneClientNoises {
  repeated float noise = 1;
}

// One share entry: fl_id, the share bytes and an index.
message ClientShareStr {
  string fl_id = 1;
  bytes share = 2;  // TODO: verify the correctness
  int32 index = 3;
}

// A list of ClientShareStr entries.
message SharesPb {
  repeated ClientShareStr clientsharestrs = 1;
}

// A list of keys, each stored as raw bytes.
message KeysPb {
  repeated bytes key = 1;
}
|
||||||
|
|
||||||
|
// Container holding exactly one of the metadata message types above.
message PBMetadata {
  oneof value {
    DeviceMeta device_meta = 1;
    FLIdToDeviceMeta device_metas = 2;

    FLId fl_id = 3;
    UpdateModelClientList client_list = 4;

    UpdateModelThreshold update_model_threshold = 5;

    PairClientShares pair_client_shares = 6;
    ClientShares client_shares = 7;

    PairClientKeys pair_client_keys = 8;
    ClientKeys client_keys = 9;

    OneClientNoises one_client_noises = 10;
    ClientNoises client_noises = 11;
  }
}

// A PBMetadata value paired with a name.
message PBMetadataWithName {
  string name = 1;
  PBMetadata metadata = 2;
}
|
|
@ -1,4 +1,4 @@
|
||||||
if(ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))
|
if(ENABLE_CPU AND NOT WIN32)
|
||||||
file(GLOB_RECURSE _PS_CACHE_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "ps_data/*.cc")
|
file(GLOB_RECURSE _PS_CACHE_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "ps_data/*.cc")
|
||||||
set_property(SOURCE ${_PS_CACHE_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PS)
|
set_property(SOURCE ${_PS_CACHE_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PS)
|
||||||
add_library(ps_cache SHARED ${_PS_CACHE_SRC_FILES})
|
add_library(ps_cache SHARED ${_PS_CACHE_SRC_FILES})
|
||||||
|
|
|
@ -18,7 +18,7 @@
|
||||||
#include "utils/log_adapter.h"
|
#include "utils/log_adapter.h"
|
||||||
#include "utils/ms_utils.h"
|
#include "utils/ms_utils.h"
|
||||||
#include "backend/kernel_compiler/kernel.h"
|
#include "backend/kernel_compiler/kernel.h"
|
||||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
#if (ENABLE_CPU && !_WIN32)
|
||||||
#include "ps/ps_cache/ps_cache_manager.h"
|
#include "ps/ps_cache/ps_cache_manager.h"
|
||||||
#include "ps/ps_cache/ps_data/ps_data_prefetch.h"
|
#include "ps/ps_cache/ps_data/ps_data_prefetch.h"
|
||||||
#endif
|
#endif
|
||||||
|
@ -68,7 +68,7 @@ void PSContext::Reset() {
|
||||||
is_worker_ = false;
|
is_worker_ = false;
|
||||||
is_pserver_ = false;
|
is_pserver_ = false;
|
||||||
is_sched_ = false;
|
is_sched_ = false;
|
||||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
#if (ENABLE_CPU && !_WIN32)
|
||||||
if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
|
if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
|
||||||
ps_cache_instance.Finalize();
|
ps_cache_instance.Finalize();
|
||||||
set_cache_enable(false);
|
set_cache_enable(false);
|
||||||
|
@ -108,46 +108,62 @@ int PSContext::ps_rank_id() const { return rank_id_; }
|
||||||
|
|
||||||
void PSContext::InsertHashTableSize(const std::string ¶m_name, size_t cache_vocab_size, size_t embedding_size,
|
void PSContext::InsertHashTableSize(const std::string ¶m_name, size_t cache_vocab_size, size_t embedding_size,
|
||||||
size_t vocab_size) const {
|
size_t vocab_size) const {
|
||||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
#if (ENABLE_CPU && !_WIN32)
|
||||||
ps_cache_instance.InsertHashTableSize(param_name, cache_vocab_size, embedding_size, vocab_size);
|
ps_cache_instance.InsertHashTableSize(param_name, cache_vocab_size, embedding_size, vocab_size);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
void PSContext::ReInsertHashTableSize(const std::string &new_param_name, const std::string &cur_param_name,
|
void PSContext::ReInsertHashTableSize(const std::string &new_param_name, const std::string &cur_param_name,
|
||||||
size_t cache_vocab_size, size_t embedding_size) const {
|
size_t cache_vocab_size, size_t embedding_size) const {
|
||||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
#if (ENABLE_CPU && !_WIN32)
|
||||||
ps_cache_instance.ReInsertHashTableSize(new_param_name, cur_param_name, cache_vocab_size, embedding_size);
|
ps_cache_instance.ReInsertHashTableSize(new_param_name, cur_param_name, cache_vocab_size, embedding_size);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
void PSContext::InsertWeightInitInfo(const std::string ¶m_name, size_t global_seed, size_t op_seed) const {
|
void PSContext::InsertWeightInitInfo(const std::string ¶m_name, size_t global_seed, size_t op_seed) const {
|
||||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
#if (ENABLE_CPU && !_WIN32)
|
||||||
ps_cache_instance.InsertWeightInitInfo(param_name, global_seed, op_seed);
|
ps_cache_instance.InsertWeightInitInfo(param_name, global_seed, op_seed);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
void PSContext::InsertAccumuInitInfo(const std::string ¶m_name, float init_val) const {
|
void PSContext::InsertAccumuInitInfo(const std::string ¶m_name, float init_val) const {
|
||||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
#if (ENABLE_CPU && !_WIN32)
|
||||||
ps_cache_instance.InsertAccumuInitInfo(param_name, init_val);
|
ps_cache_instance.InsertAccumuInitInfo(param_name, init_val);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
void PSContext::CloneHashTable(const std::string &dest_param_name, const std::string &src_param_name) const {
|
void PSContext::CloneHashTable(const std::string &dest_param_name, const std::string &src_param_name) const {
|
||||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
#if (ENABLE_CPU && !_WIN32)
|
||||||
ps_cache_instance.CloneHashTable(dest_param_name, src_param_name);
|
ps_cache_instance.CloneHashTable(dest_param_name, src_param_name);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
void PSContext::set_cache_enable(bool cache_enable) const {
|
void PSContext::set_cache_enable(bool cache_enable) const {
|
||||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
#if (ENABLE_CPU && !_WIN32)
|
||||||
PsDataPrefetch::GetInstance().set_cache_enable(cache_enable);
|
PsDataPrefetch::GetInstance().set_cache_enable(cache_enable);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
void PSContext::set_rank_id(int rank_id) const {
|
void PSContext::set_rank_id(int rank_id) const {
|
||||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
#if (ENABLE_CPU && !_WIN32)
|
||||||
ps_cache_instance.set_rank_id(rank_id);
|
ps_cache_instance.set_rank_id(rank_id);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void PSContext::set_fl_name(const std::string &fl_name) { fl_name_ = fl_name; }
|
||||||
|
|
||||||
|
const std::string &PSContext::fl_name() const { return fl_name_; }
|
||||||
|
|
||||||
|
void PSContext::set_fl_iteration_num(uint64_t fl_iteration_num) { fl_iteration_num_ = fl_iteration_num; }
|
||||||
|
|
||||||
|
uint64_t PSContext::fl_iteration_num() const { return fl_iteration_num_; }
|
||||||
|
|
||||||
|
void PSContext::set_client_epoch_num(uint64_t client_epoch_num) { client_epoch_num_ = client_epoch_num; }
|
||||||
|
|
||||||
|
uint64_t PSContext::client_epoch_num() const { return client_epoch_num_; }
|
||||||
|
|
||||||
|
void PSContext::set_client_batch_size(uint64_t client_batch_size) { client_batch_size_ = client_batch_size; }
|
||||||
|
|
||||||
|
uint64_t PSContext::client_batch_size() const { return client_batch_size_; }
|
||||||
} // namespace ps
|
} // namespace ps
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -60,6 +60,19 @@ class PSContext {
|
||||||
void set_cache_enable(bool cache_enable) const;
|
void set_cache_enable(bool cache_enable) const;
|
||||||
void set_rank_id(int rank_id) const;
|
void set_rank_id(int rank_id) const;
|
||||||
|
|
||||||
|
// Setter and getter for federated learning.
|
||||||
|
void set_fl_name(const std::string &fl_name);
|
||||||
|
const std::string &fl_name() const;
|
||||||
|
|
||||||
|
void set_fl_iteration_num(uint64_t fl_iteration_num);
|
||||||
|
uint64_t fl_iteration_num() const;
|
||||||
|
|
||||||
|
void set_client_epoch_num(uint64_t client_epoch_num);
|
||||||
|
uint64_t client_epoch_num() const;
|
||||||
|
|
||||||
|
void set_client_batch_size(uint64_t client_batch_size);
|
||||||
|
uint64_t client_batch_size() const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
PSContext()
|
PSContext()
|
||||||
: ps_enabled_(false),
|
: ps_enabled_(false),
|
||||||
|
@ -80,6 +93,12 @@ class PSContext {
|
||||||
uint32_t server_num_;
|
uint32_t server_num_;
|
||||||
std::string scheduler_host_;
|
std::string scheduler_host_;
|
||||||
uint16_t scheduler_port_;
|
uint16_t scheduler_port_;
|
||||||
|
|
||||||
|
// Members for federated learning.
|
||||||
|
std::string fl_name_;
|
||||||
|
uint64_t fl_iteration_num_;
|
||||||
|
uint64_t client_epoch_num_;
|
||||||
|
uint64_t client_batch_size_;
|
||||||
};
|
};
|
||||||
} // namespace ps
|
} // namespace ps
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -0,0 +1,223 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "ps/server/collective_ops_impl.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace ps {
|
||||||
|
namespace server {
|
||||||
|
// Binds this object to the given server node and caches the values the AllReduce code reads:
// the node's rank id and the initial server number taken from PSContext.
// Raises through MS_EXCEPTION_IF_NULL when server_node is null.
void CollectiveOpsImpl::Initialize(const std::shared_ptr<core::ServerNode> &server_node) {
  MS_EXCEPTION_IF_NULL(server_node);
  server_node_ = server_node;
  local_rank_ = server_node_->rank_id();
  server_num_ = PSContext::instance()->initial_server_num();
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
bool CollectiveOpsImpl::RingAllReduce(const void *sendbuff, void *recvbuff, size_t count) {
|
||||||
|
int ret = memcpy_s(recvbuff, count * sizeof(T), sendbuff, count * sizeof(T));
|
||||||
|
if (ret != 0) {
|
||||||
|
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t rank_size = server_num_;
|
||||||
|
uint32_t local_rank_ = server_node_->rank_id();
|
||||||
|
size_t chunk_size = count / rank_size;
|
||||||
|
size_t remainder_size = count % rank_size;
|
||||||
|
std::vector<size_t> chunk_sizes(rank_size, chunk_size);
|
||||||
|
// The rest of the data should be assigned to each chunk.
|
||||||
|
for (size_t i = 0; i < remainder_size; i++) {
|
||||||
|
chunk_sizes[i]++;
|
||||||
|
}
|
||||||
|
// Store offsets to get every data chunk's address.
|
||||||
|
std::vector<size_t> chunk_offset;
|
||||||
|
for (size_t i = 0; i < rank_size; i++) {
|
||||||
|
size_t ofs =
|
||||||
|
std::accumulate(chunk_sizes.begin(), chunk_sizes.begin() + i, static_cast<size_t>(0), std::plus<size_t>());
|
||||||
|
chunk_offset.push_back(ofs);
|
||||||
|
}
|
||||||
|
|
||||||
|
T *output_buff = reinterpret_cast<T *>(recvbuff);
|
||||||
|
uint32_t send_to_rank = (local_rank_ + 1) % rank_size;
|
||||||
|
uint32_t recv_from_rank = (local_rank_ - 1 + rank_size) % rank_size;
|
||||||
|
MS_LOG(DEBUG) << "AllReduce count:" << count << ", rank_size:" << rank_size << ", local_rank_:" << local_rank_
|
||||||
|
<< ", chunk_size:" << chunk_size << ", remainder_size:" << remainder_size
|
||||||
|
<< ", chunk_sizes:" << chunk_sizes << ", send_to_rank:" << send_to_rank
|
||||||
|
<< ", recv_from_rank:" << recv_from_rank;
|
||||||
|
|
||||||
|
// Ring ReduceScatter.
|
||||||
|
MS_LOG(DEBUG) << "Start Ring ReduceScatter.";
|
||||||
|
std::unique_ptr<T[]> tmp_recv_chunk = std::make_unique<T[]>(chunk_sizes[0]);
|
||||||
|
for (size_t i = 0; i < rank_size - 1; i++) {
|
||||||
|
// Step 1: Async send data to next rank.
|
||||||
|
size_t send_chunk_index = (local_rank_ - i + rank_size) % rank_size;
|
||||||
|
T *send_chunk = output_buff + chunk_offset[send_chunk_index];
|
||||||
|
auto send_req_id = server_node_->CollectiveSendAsync(core::NodeRole::SERVER, send_to_rank, send_chunk,
|
||||||
|
chunk_sizes[send_chunk_index] * sizeof(T));
|
||||||
|
// Step 2: Async receive data to next rank and wait until it's done.
|
||||||
|
size_t recv_chunk_index = (local_rank_ - i - 1 + rank_size) % rank_size;
|
||||||
|
T *recv_chunk = output_buff + chunk_offset[recv_chunk_index];
|
||||||
|
MS_LOG(DEBUG) << "Ring ReduceScatter send_to_rank:" << send_to_rank << ", recv_from_rank:" << recv_from_rank
|
||||||
|
<< ", send count:" << chunk_sizes[send_chunk_index]
|
||||||
|
<< ", recv count:" << chunk_sizes[recv_chunk_index] << ", iteration:" << i;
|
||||||
|
|
||||||
|
std::shared_ptr<std::vector<unsigned char>> recv_str;
|
||||||
|
auto recv_req_id = server_node_->CollectiveReceiveAsync(core::NodeRole::SERVER, recv_from_rank, &recv_str);
|
||||||
|
if (!server_node_->CollectiveWait(recv_req_id, 1)) {
|
||||||
|
MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
memcpy_s(tmp_recv_chunk.get(), chunk_sizes[recv_chunk_index] * sizeof(T), recv_str->data(), recv_str->size());
|
||||||
|
|
||||||
|
// Step 3: Reduce the data so we can overlap the time cost of send.
|
||||||
|
for (size_t j = 0; j < chunk_sizes[recv_chunk_index]; j++) {
|
||||||
|
recv_chunk[j] += tmp_recv_chunk[j];
|
||||||
|
}
|
||||||
|
// Step 4: Wait until send is done.
|
||||||
|
if (!server_node_->Wait(send_req_id, 1)) {
|
||||||
|
MS_LOG(ERROR) << "CollectiveWait " << send_req_id << " failed.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
MS_LOG(DEBUG) << "End Ring ReduceScatter.";
|
||||||
|
|
||||||
|
// Ring AllGather.
|
||||||
|
MS_LOG(DEBUG) << "Start Ring AllGather.";
|
||||||
|
for (size_t i = 0; i < rank_size - 1; i++) {
|
||||||
|
size_t send_chunk_index = (local_rank_ - i + 1 + rank_size) % rank_size;
|
||||||
|
T *send_chunk = output_buff + chunk_offset[send_chunk_index];
|
||||||
|
auto send_req_id = server_node_->CollectiveSendAsync(core::NodeRole::SERVER, send_to_rank, send_chunk,
|
||||||
|
chunk_sizes[send_chunk_index] * sizeof(T));
|
||||||
|
size_t recv_chunk_index = (local_rank_ - i + rank_size) % rank_size;
|
||||||
|
T *recv_chunk = output_buff + chunk_offset[recv_chunk_index];
|
||||||
|
MS_LOG(DEBUG) << "Ring AllGather send_to_rank:" << send_to_rank << ", recv_from_rank:" << recv_from_rank
|
||||||
|
<< ", send count:" << chunk_sizes[send_chunk_index]
|
||||||
|
<< ", recv count:" << chunk_sizes[recv_chunk_index] << ", iteration:" << i;
|
||||||
|
|
||||||
|
std::shared_ptr<std::vector<unsigned char>> recv_str;
|
||||||
|
auto recv_req_id = server_node_->CollectiveReceiveAsync(core::NodeRole::SERVER, recv_from_rank, &recv_str);
|
||||||
|
|
||||||
|
if (!server_node_->CollectiveWait(recv_req_id, 1)) {
|
||||||
|
MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
memcpy_s(recv_chunk, chunk_sizes[recv_chunk_index] * sizeof(T), recv_str->data(), recv_str->size());
|
||||||
|
if (!server_node_->Wait(send_req_id, 1)) {
|
||||||
|
MS_LOG(ERROR) << "CollectiveWait " << send_req_id << " failed.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
MS_LOG(DEBUG) << "End Ring AllGather.";
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
bool CollectiveOpsImpl::ReduceBroadcastAllReduce(const void *sendbuff, void *recvbuff, size_t count) {
|
||||||
|
uint32_t rank_size = server_num_;
|
||||||
|
uint32_t local_rank_ = server_node_->rank_id();
|
||||||
|
MS_LOG(DEBUG) << "Reduce Broadcast AllReduce rank_size:" << rank_size << ", local_rank_:" << local_rank_
|
||||||
|
<< ", count:" << count;
|
||||||
|
int ret = memcpy_s(recvbuff, count * sizeof(T), sendbuff, count * sizeof(T));
|
||||||
|
if (ret != 0) {
|
||||||
|
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
T *output_buff = reinterpret_cast<T *>(recvbuff);
|
||||||
|
// Reduce data to rank 0 process.
|
||||||
|
MS_LOG(DEBUG) << "Start Reduce to rank 0 process.";
|
||||||
|
if (local_rank_ == 0) {
|
||||||
|
std::unique_ptr<T[]> tmp_recv_buff = std::make_unique<T[]>(count);
|
||||||
|
for (uint32_t i = 1; i < rank_size; i++) {
|
||||||
|
std::shared_ptr<std::vector<unsigned char>> recv_str;
|
||||||
|
MS_LOG(DEBUG) << "Reduce rank 0 receive from rank " << i;
|
||||||
|
auto recv_req_id = server_node_->CollectiveReceiveAsync(core::NodeRole::SERVER, i, &recv_str);
|
||||||
|
if (!server_node_->CollectiveWait(recv_req_id, 1)) {
|
||||||
|
MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
memcpy_s(tmp_recv_buff.get(), count * sizeof(T), recv_str->data(), recv_str->size());
|
||||||
|
for (size_t j = 0; j < count; j++) {
|
||||||
|
output_buff[j] += tmp_recv_buff[j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
MS_LOG(DEBUG) << "Reduce send data to rank 0 process.";
|
||||||
|
auto send_req_id = server_node_->CollectiveSendAsync(core::NodeRole::SERVER, 0, sendbuff, count * sizeof(T));
|
||||||
|
if (!server_node_->Wait(send_req_id, 1)) {
|
||||||
|
MS_LOG(ERROR) << "CollectiveWait " << send_req_id << " failed.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
MS_LOG(DEBUG) << "End Reduce.";
|
||||||
|
|
||||||
|
// Broadcast data to not 0 rank process.
|
||||||
|
MS_LOG(DEBUG) << "Start broadcast from rank 0 to other processes.";
|
||||||
|
if (local_rank_ == 0) {
|
||||||
|
for (uint32_t i = 1; i < rank_size; i++) {
|
||||||
|
MS_LOG(DEBUG) << "Broadcast data to process " << i;
|
||||||
|
auto send_req_id = server_node_->CollectiveSendAsync(core::NodeRole::SERVER, i, output_buff, count * sizeof(T));
|
||||||
|
if (!server_node_->Wait(send_req_id, 1)) {
|
||||||
|
MS_LOG(ERROR) << "CollectiveWait " << send_req_id << " failed.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
MS_LOG(DEBUG) << "Broadcast receive from rank 0.";
|
||||||
|
std::shared_ptr<std::vector<unsigned char>> recv_str;
|
||||||
|
auto recv_req_id = server_node_->CollectiveReceiveAsync(core::NodeRole::SERVER, 0, &recv_str);
|
||||||
|
if (!server_node_->CollectiveWait(recv_req_id, 1)) {
|
||||||
|
MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
memcpy_s(output_buff, count * sizeof(T), recv_str->data(), recv_str->size());
|
||||||
|
}
|
||||||
|
MS_LOG(DEBUG) << "End broadcast.";
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
bool CollectiveOpsImpl::AllReduce(const void *sendbuff, void *recvbuff, size_t count) {
|
||||||
|
// The collective communication API does not support calling Send and Recv concurrently with multiple threads;
|
||||||
|
std::unique_lock<std::mutex> lock(mtx_);
|
||||||
|
if (sendbuff == nullptr || recvbuff == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "AllReduce sendbuff or recvbuff is nullptr.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t rank_size = server_num_;
|
||||||
|
if (count >= rank_size) {
|
||||||
|
return RingAllReduce<T>(sendbuff, recvbuff, count);
|
||||||
|
} else {
|
||||||
|
return ReduceBroadcastAllReduce<T>(sendbuff, recvbuff, count);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template bool CollectiveOpsImpl::RingAllReduce<float>(const void *sendbuff, void *recvbuff, size_t count);
|
||||||
|
template bool CollectiveOpsImpl::RingAllReduce<size_t>(const void *sendbuff, void *recvbuff, size_t count);
|
||||||
|
template bool CollectiveOpsImpl::RingAllReduce<int>(const void *sendbuff, void *recvbuff, size_t count);
|
||||||
|
|
||||||
|
template bool CollectiveOpsImpl::ReduceBroadcastAllReduce<float>(const void *sendbuff, void *recvbuff, size_t count);
|
||||||
|
template bool CollectiveOpsImpl::ReduceBroadcastAllReduce<size_t>(const void *sendbuff, void *recvbuff, size_t count);
|
||||||
|
template bool CollectiveOpsImpl::ReduceBroadcastAllReduce<int>(const void *sendbuff, void *recvbuff, size_t count);
|
||||||
|
|
||||||
|
template bool CollectiveOpsImpl::AllReduce<float>(const void *sendbuff, void *recvbuff, size_t count);
|
||||||
|
template bool CollectiveOpsImpl::AllReduce<size_t>(const void *sendbuff, void *recvbuff, size_t count);
|
||||||
|
template bool CollectiveOpsImpl::AllReduce<int>(const void *sendbuff, void *recvbuff, size_t count);
|
||||||
|
} // namespace server
|
||||||
|
} // namespace ps
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,71 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_PS_SERVER_COLLECTIVE_OPS_IMPL_H_
|
||||||
|
#define MINDSPORE_CCSRC_PS_SERVER_COLLECTIVE_OPS_IMPL_H_
|
||||||
|
|
||||||
|
#include <memory>
#include <mutex>
#include <string>
#include <vector>
#include <functional>
#include "proto/ps.pb.h"
#include "ps/ps_context.h"
#include "ps/core/server_node.h"
#include "ps/server/common.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace ps {
|
||||||
|
namespace server {
|
||||||
|
// CollectiveOpsImpl is the collective communication API of the server.
|
||||||
|
// For now, it implements two AllReduce algorithms: RingAllReduce and BroadcastAllReduce. Elastic AllReduce is also
|
||||||
|
// supported for the elastic scaling feature of the server.
|
||||||
|
class CollectiveOpsImpl {
|
||||||
|
public:
|
||||||
|
static CollectiveOpsImpl &GetInstance() {
|
||||||
|
static CollectiveOpsImpl instance;
|
||||||
|
return instance;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Initialize(const std::shared_ptr<core::ServerNode> &server_node);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
bool AllReduce(const void *sendbuff, void *recvbuff, size_t count);
|
||||||
|
|
||||||
|
private:
|
||||||
|
CollectiveOpsImpl() = default;
|
||||||
|
~CollectiveOpsImpl() = default;
|
||||||
|
CollectiveOpsImpl(const CollectiveOpsImpl &) = delete;
|
||||||
|
CollectiveOpsImpl &operator=(const CollectiveOpsImpl &) = delete;
|
||||||
|
|
||||||
|
// Implementation of RingAllReduce.
|
||||||
|
template <typename T>
|
||||||
|
bool RingAllReduce(const void *sendbuff, void *recvbuff, size_t count);
|
||||||
|
|
||||||
|
// Implementation of BroadcastAllReduce.
|
||||||
|
template <typename T>
|
||||||
|
bool ReduceBroadcastAllReduce(const void *sendbuff, void *recvbuff, size_t count);
|
||||||
|
|
||||||
|
std::shared_ptr<core::ServerNode> server_node_;
|
||||||
|
uint32_t local_rank_;
|
||||||
|
uint32_t server_num_;
|
||||||
|
|
||||||
|
// The mutex to ensure that collective communication is threadsafe.
|
||||||
|
std::mutex mtx_;
|
||||||
|
};
|
||||||
|
} // namespace server
|
||||||
|
} // namespace ps
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_CCSRC_PS_SERVER_COLLECTIVE_OPS_IMPL_H_
|
|
@ -24,13 +24,17 @@
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include "proto/ps.pb.h"
|
#include "proto/ps.pb.h"
|
||||||
|
#include "proto/fl.pb.h"
|
||||||
#include "ir/anf.h"
|
#include "ir/anf.h"
|
||||||
#include "utils/utils.h"
|
#include "utils/utils.h"
|
||||||
#include "ir/dtype/type_id.h"
|
#include "ir/dtype/type_id.h"
|
||||||
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
|
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
|
||||||
|
#include "schema/fl_job_generated.h"
|
||||||
|
#include "schema/cipher_generated.h"
|
||||||
#include "ps/ps_context.h"
|
#include "ps/ps_context.h"
|
||||||
#include "ps/core/communicator/http_message_handler.h"
|
#include "ps/core/communicator/http_message_handler.h"
|
||||||
#include "ps/core/communicator/tcp_server.h"
|
#include "ps/core/communicator/tcp_server.h"
|
||||||
|
#include "ps/core/communicator/message_handler.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace ps {
|
||||||
|
@ -40,13 +44,15 @@ enum ServerMode { PARAMETER_SERVER = 0, FL_SERVER };
|
||||||
enum CommType { HTTP = 0, TCP };
|
enum CommType { HTTP = 0, TCP };
|
||||||
enum AggregationType { FedAvg = 0, FedAdam, FedAdagarg, FedMeta, qffl, DenseGradAccum, SparseGradAccum };
|
enum AggregationType { FedAvg = 0, FedAdam, FedAdagarg, FedMeta, qffl, DenseGradAccum, SparseGradAccum };
|
||||||
|
|
||||||
using kernel::Address;
|
using mindspore::kernel::Address;
|
||||||
using kernel::AddressPtr;
|
using mindspore::kernel::AddressPtr;
|
||||||
using kernel::CPUKernel;
|
using mindspore::kernel::CPUKernel;
|
||||||
|
using FBBuilder = flatbuffers::FlatBufferBuilder;
|
||||||
using TimeOutCb = std::function<void(void)>;
|
using TimeOutCb = std::function<void(void)>;
|
||||||
using StopTimerCb = std::function<void(void)>;
|
using StopTimerCb = std::function<void(void)>;
|
||||||
using FinishIterCb = std::function<void(void)>;
|
using FinishIterCb = std::function<void(void)>;
|
||||||
using FinalizeCb = std::function<void(void)>;
|
using FinalizeCb = std::function<void(void)>;
|
||||||
|
using MessageCallback = std::function<void(const std::shared_ptr<core::MessageHandler> &)>;
|
||||||
|
|
||||||
// Information about whether server kernel will reuse kernel node memory from the front end.
|
// Information about whether server kernel will reuse kernel node memory from the front end.
|
||||||
// Key refers to the server kernel's parameter name, like "weights", "grad", "learning_rate".
|
// Key refers to the server kernel's parameter name, like "weights", "grad", "learning_rate".
|
||||||
|
|
|
@ -0,0 +1,298 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "ps/server/distributed_count_service.h"
|
||||||
|
#include <string>
|
||||||
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace ps {
|
||||||
|
namespace server {
|
||||||
|
// Binds the counting service to a server node and records which rank performs the real counting.
//
// server_node: node used to obtain the TCP communicator and the local rank id. Must not be null.
// counting_server_rank: rank of the server that keeps the cluster-wide counters.
void DistributedCountService::Initialize(const std::shared_ptr<core::ServerNode> &server_node,
                                         uint32_t counting_server_rank) {
  server_node_ = server_node;
  MS_EXCEPTION_IF_NULL(server_node_);

  // NOTE(review): the placeholder arguments presumably fetch an already-created communicator - confirm against
  // ServerNode::GetOrCreateTcpComm.
  auto tcp_comm = server_node_->GetOrCreateTcpComm("", 0, 0, 0, nullptr);
  communicator_ = std::dynamic_pointer_cast<core::TcpCommunicator>(tcp_comm);
  MS_EXCEPTION_IF_NULL(communicator_);

  counting_server_rank_ = counting_server_rank;
  server_num_ = PSContext::instance()->initial_server_num();
  local_rank_ = server_node_->rank_id();

  // Callback registration compares local_rank_ with counting_server_rank_, so it must run last.
  RegisterCallback();
}
|
||||||
|
|
||||||
|
void DistributedCountService::RegisterCounter(const std::string &name, size_t global_threshold_count,
|
||||||
|
const CounterHandlers &counter_handlers) {
|
||||||
|
if (!counter_handlers.first_count_handler || !counter_handlers.last_count_handler) {
|
||||||
|
MS_LOG(EXCEPTION) << "First count handler or last count handler is not set.";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (global_threshold_count_.count(name) != 0) {
|
||||||
|
MS_LOG(ERROR) << "Counter for " << name << " is already set.";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
MS_LOG(INFO) << "Rank " << local_rank_ << " register counter for " << name << " count:" << global_threshold_count;
|
||||||
|
// If the server is the leader server, it needs to set the counter handlers and do the real counting.
|
||||||
|
if (local_rank_ == counting_server_rank_) {
|
||||||
|
global_current_count_[name] = {};
|
||||||
|
global_threshold_count_[name] = global_threshold_count;
|
||||||
|
mutex_[name];
|
||||||
|
}
|
||||||
|
counter_handlers_[name] = counter_handlers;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool DistributedCountService::Count(const std::string &name, const std::string &id) {
|
||||||
|
MS_LOG(INFO) << "Rank " << local_rank_ << " reports count for " << name << " of " << id;
|
||||||
|
if (local_rank_ == counting_server_rank_) {
|
||||||
|
if (global_threshold_count_.count(name) == 0) {
|
||||||
|
MS_LOG(ERROR) << "Counter for " << name << " is not registered.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_lock<std::mutex> lock(mutex_[name]);
|
||||||
|
if (global_current_count_[name].size() >= global_threshold_count_[name]) {
|
||||||
|
MS_LOG(ERROR) << "Count for " << name << " is already enough. Threshold count is "
|
||||||
|
<< global_threshold_count_[name];
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
MS_LOG(INFO) << "Leader server increase count for " << name << " of " << id;
|
||||||
|
global_current_count_[name].insert(id);
|
||||||
|
TriggerCounterEvent(name);
|
||||||
|
} else {
|
||||||
|
// If this server is a follower server, it needs to send CountRequest to the leader server.
|
||||||
|
CountRequest report_count_req;
|
||||||
|
report_count_req.set_name(name);
|
||||||
|
report_count_req.set_id(id);
|
||||||
|
|
||||||
|
std::shared_ptr<std::vector<unsigned char>> report_cnt_rsp_msg = nullptr;
|
||||||
|
if (!communicator_->SendPbRequest(report_count_req, counting_server_rank_, core::TcpUserCommand::kCount,
|
||||||
|
&report_cnt_rsp_msg)) {
|
||||||
|
MS_LOG(ERROR) << "Sending reporting count message to leader server failed for " << name;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
CountResponse count_rsp;
|
||||||
|
count_rsp.ParseFromArray(report_cnt_rsp_msg->data(), report_cnt_rsp_msg->size());
|
||||||
|
if (!count_rsp.result()) {
|
||||||
|
MS_LOG(ERROR) << "Reporting count failed:" << count_rsp.reason();
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool DistributedCountService::CountReachThreshold(const std::string &name) {
|
||||||
|
MS_LOG(INFO) << "Rank " << local_rank_ << " query whether count reaches threshold for " << name;
|
||||||
|
if (local_rank_ == counting_server_rank_) {
|
||||||
|
if (global_threshold_count_.count(name) == 0) {
|
||||||
|
MS_LOG(ERROR) << "Counter for " << name << " is not set.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_lock<std::mutex> lock(mutex_[name]);
|
||||||
|
return global_current_count_[name].size() == global_threshold_count_[name];
|
||||||
|
} else {
|
||||||
|
CountReachThresholdRequest count_reach_threashold_req;
|
||||||
|
count_reach_threashold_req.set_name(name);
|
||||||
|
|
||||||
|
std::shared_ptr<std::vector<unsigned char>> query_cnt_enough_rsp_msg = nullptr;
|
||||||
|
if (!communicator_->SendPbRequest(count_reach_threashold_req, counting_server_rank_,
|
||||||
|
core::TcpUserCommand::kReachThreshold, &query_cnt_enough_rsp_msg)) {
|
||||||
|
MS_LOG(ERROR) << "Sending querying whether count reaches threshold message to leader server failed for " << name;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
CountReachThresholdResponse count_reach_threashold_rsp;
|
||||||
|
count_reach_threashold_rsp.ParseFromArray(query_cnt_enough_rsp_msg->data(), query_cnt_enough_rsp_msg->size());
|
||||||
|
return count_reach_threashold_rsp.is_enough();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void DistributedCountService::ResetCounter(const std::string &name) {
|
||||||
|
if (local_rank_ == counting_server_rank_) {
|
||||||
|
MS_LOG(INFO) << "Leader server reset count for " << name;
|
||||||
|
global_current_count_[name].clear();
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
void DistributedCountService::RegisterCallback() {
|
||||||
|
if (local_rank_ == counting_server_rank_) {
|
||||||
|
communicator_->RegisterMsgCallBack(
|
||||||
|
"count", std::bind(&DistributedCountService::HandleCountRequest, this, std::placeholders::_1));
|
||||||
|
communicator_->RegisterMsgCallBack(
|
||||||
|
"countReachThreshold",
|
||||||
|
std::bind(&DistributedCountService::HandleCountReachThresholdRequest, this, std::placeholders::_1));
|
||||||
|
}
|
||||||
|
|
||||||
|
// The callback of first/last event must be set in both leader server and follower servers.
|
||||||
|
communicator_->RegisterMsgCallBack(
|
||||||
|
"counterEvent", std::bind(&DistributedCountService::HandleCounterEvent, this, std::placeholders::_1));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handles a CountRequest sent by another server. Registered only on the counting server.
//
// A CountResponse is always sent back: `result` is false with a `reason` when the counter is not registered or is
// already full, and true otherwise.
void DistributedCountService::HandleCountRequest(const std::shared_ptr<core::MessageHandler> &message) {
  if (message == nullptr) {
    MS_LOG(ERROR) << "Message is nullptr.";
    return;
  }

  CountRequest report_count_req;
  report_count_req.ParseFromArray(message->data(), message->len());
  const std::string &name = report_count_req.name();
  const std::string &id = report_count_req.id();

  // Serialize the response exactly once and send it back to the requester.
  const auto send_response = [this, &message](const CountResponse &rsp) {
    const std::string payload = rsp.SerializeAsString();
    communicator_->SendResponse(payload.data(), payload.size(), message);
  };

  CountResponse count_rsp;
  // If leader server has no counter for the name registered, return an error. This is checked before touching
  // mutex_[name] so that an unknown name coming from the network never inserts a new entry into the mutex map.
  if (global_threshold_count_.count(name) == 0) {
    std::string reason = "Counter for " + name + " is not registered.";
    count_rsp.set_result(false);
    count_rsp.set_reason(reason);
    MS_LOG(ERROR) << reason;
    send_response(count_rsp);
    return;
  }

  std::unique_lock<std::mutex> lock(mutex_[name]);
  auto &counted_ids = global_current_count_[name];
  const size_t threshold = global_threshold_count_[name];

  // If leader server already has enough count for the name, return an error.
  if (counted_ids.size() >= threshold) {
    std::string reason =
      "Count for " + name + " is already enough. Threshold count is " + std::to_string(threshold);
    count_rsp.set_result(false);
    count_rsp.set_reason(reason);
    MS_LOG(ERROR) << reason;
    send_response(count_rsp);
    return;
  }

  // Insert the id for the counter, which means the count for the name is increased.
  MS_LOG(INFO) << "Leader server increase count for " << name << " of " << id;
  if (counted_ids.insert(id).second) {
    TriggerCounterEvent(name);
  } else {
    // A repeated id does not change the count, so the first/last count events must not fire again.
    MS_LOG(INFO) << "Id " << id << " is already counted for " << name << ". Ignore the repeated count.";
  }
  count_rsp.set_result(true);
  count_rsp.set_reason("success");
  send_response(count_rsp);
}
|
||||||
|
|
||||||
|
// Handles a CountReachThresholdRequest sent by another server. Registered only on the counting server.
//
// A CountReachThresholdResponse is always sent back. `is_enough` is true only when the counter is registered and the
// number of counted ids equals its threshold.
void DistributedCountService::HandleCountReachThresholdRequest(const std::shared_ptr<core::MessageHandler> &message) {
  if (message == nullptr) {
    MS_LOG(ERROR) << "Message is nullptr.";
    return;
  }

  CountReachThresholdRequest reach_threshold_req;
  reach_threshold_req.ParseFromArray(message->data(), message->len());
  const std::string &name = reach_threshold_req.name();

  CountReachThresholdResponse reach_threshold_rsp;
  if (global_threshold_count_.count(name) == 0) {
    // Still answer the requester instead of leaving it without a response. The registration check comes before
    // mutex_[name] so that an unknown name never inserts a new entry into the mutex map.
    MS_LOG(ERROR) << "Counter for " << name << " is not registered.";
    reach_threshold_rsp.set_is_enough(false);
  } else {
    std::unique_lock<std::mutex> lock(mutex_[name]);
    reach_threshold_rsp.set_is_enough(global_current_count_[name].size() == global_threshold_count_[name]);
  }

  // Serialize once; the payload must outlive the SendResponse call.
  const std::string payload = reach_threshold_rsp.SerializeAsString();
  communicator_->SendResponse(payload.data(), payload.size(), message);
}
|
||||||
|
|
||||||
|
// Handles a CounterEvent message: invokes the first or last count handler registered for the counter named in the
// event. The callback is registered on every server.
void DistributedCountService::HandleCounterEvent(const std::shared_ptr<core::MessageHandler> &message) {
  if (message == nullptr) {
    MS_LOG(ERROR) << "Message is nullptr.";
    return;
  }

  // Respond as soon as possible so the leader server won't wait for each follower servers to finish calling the
  // callbacks.
  std::string counter_event_rsp_msg = "success";
  communicator_->SendResponse(counter_event_rsp_msg.data(), counter_event_rsp_msg.size(), message);

  CounterEvent counter_event;
  counter_event.ParseFromArray(message->data(), message->len());
  const auto type = counter_event.type();
  const auto &name = counter_event.name();

  MS_LOG(INFO) << "Rank " << local_rank_ << " do counter event " << type << " for " << name;

  // Look the handlers up without operator[]: it would insert empty handlers for an unknown name, and calling an empty
  // std::function throws std::bad_function_call.
  const auto handlers_iter = counter_handlers_.find(name);
  if (handlers_iter == counter_handlers_.end()) {
    MS_LOG(ERROR) << "Counter handlers for " << name << " are not registered.";
    return;
  }

  const CounterHandlers &handlers = handlers_iter->second;
  if (type == CounterEventType::FIRST_CNT) {
    handlers.first_count_handler(message);
  } else if (type == CounterEventType::LAST_CNT) {
    handlers.last_count_handler(message);
  } else {
    MS_LOG(ERROR) << "DistributedCountService event type " << type << " is invalid.";
  }
}
|
||||||
|
|
||||||
|
void DistributedCountService::TriggerCounterEvent(const std::string &name) {
|
||||||
|
MS_LOG(INFO) << "Current count for " << name << " is " << global_current_count_[name].size()
|
||||||
|
<< ", threshold count is " << global_threshold_count_[name];
|
||||||
|
// The threshold count may be 1 so the first and last count event should be both activated.
|
||||||
|
if (global_current_count_[name].size() == 1) {
|
||||||
|
TriggerFirstCountEvent(name);
|
||||||
|
}
|
||||||
|
if (global_current_count_[name].size() == global_threshold_count_[name]) {
|
||||||
|
TriggerLastCountEvent(name);
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
void DistributedCountService::TriggerFirstCountEvent(const std::string &name) {
|
||||||
|
MS_LOG(INFO) << "Activating first count event for " << name;
|
||||||
|
CounterEvent first_count_event;
|
||||||
|
first_count_event.set_type(CounterEventType::FIRST_CNT);
|
||||||
|
first_count_event.set_name(name);
|
||||||
|
|
||||||
|
// Broadcast to all follower servers.
|
||||||
|
for (uint32_t i = 1; i < server_num_; i++) {
|
||||||
|
if (!communicator_->SendPbRequest(first_count_event, i, core::TcpUserCommand::kCounterEvent)) {
|
||||||
|
MS_LOG(ERROR) << "Activating first count event to server " << i << " failed.";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Leader server directly calls the callback.
|
||||||
|
counter_handlers_[name].first_count_handler(nullptr);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
void DistributedCountService::TriggerLastCountEvent(const std::string &name) {
|
||||||
|
MS_LOG(INFO) << "Activating last count event for " << name;
|
||||||
|
CounterEvent last_count_event;
|
||||||
|
last_count_event.set_type(CounterEventType::LAST_CNT);
|
||||||
|
last_count_event.set_name(name);
|
||||||
|
|
||||||
|
// Broadcast to all follower servers.
|
||||||
|
for (uint32_t i = 1; i < server_num_; i++) {
|
||||||
|
if (!communicator_->SendPbRequest(last_count_event, i, core::TcpUserCommand::kCounterEvent)) {
|
||||||
|
MS_LOG(ERROR) << "Activating last count event to server " << i << " failed.";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Leader server directly calls the callback.
|
||||||
|
counter_handlers_[name].last_count_handler(nullptr);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
} // namespace server
|
||||||
|
} // namespace ps
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,126 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_PS_SERVER_DISTRIBUTED_COUNT_SERVICE_H_
|
||||||
|
#define MINDSPORE_CCSRC_PS_SERVER_DISTRIBUTED_COUNT_SERVICE_H_
|
||||||
|
|
||||||
|
#include <set>
|
||||||
|
#include <string>
|
||||||
|
#include <memory>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include "proto/ps.pb.h"
|
||||||
|
#include "ps/server/common.h"
|
||||||
|
#include "ps/core/server_node.h"
|
||||||
|
#include "ps/core/communicator/tcp_communicator.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace ps {
|
||||||
|
namespace server {
|
||||||
|
// The callbacks for the first count and last count event.
|
||||||
|
typedef struct {
|
||||||
|
MessageCallback first_count_handler;
|
||||||
|
MessageCallback last_count_handler;
|
||||||
|
} CounterHandlers;
|
||||||
|
|
||||||
|
// DistributedCountService is used for counting in the server cluster dimension. It's used for counting of rounds,
|
||||||
|
// aggregation counting, etc.
|
||||||
|
|
||||||
|
// The counting could be called by any server, but only one server has the information
|
||||||
|
// of the cluster count and we mark this server as the counting server. Other servers must communicate with this
|
||||||
|
// counting server to increase/query count number.
|
||||||
|
|
||||||
|
// On the first count or last count event, DistributedCountService on the counting server triggers the event on other
|
||||||
|
// servers by sending counter event commands. This is for the purpose of keeping server cluster's consistency.
|
||||||
|
class DistributedCountService {
|
||||||
|
public:
|
||||||
|
static DistributedCountService &GetInstance() {
|
||||||
|
static DistributedCountService instance;
|
||||||
|
return instance;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize counter service with the server node because communication is needed.
|
||||||
|
void Initialize(const std::shared_ptr<core::ServerNode> &server_node, uint32_t counting_server_rank);
|
||||||
|
|
||||||
|
// Register counter to the counting server for the name with its threshold count in server cluster dimension and
|
||||||
|
// first/last count event callbacks.
|
||||||
|
void RegisterCounter(const std::string &name, size_t global_threshold_count, const CounterHandlers &counter_handlers);
|
||||||
|
|
||||||
|
// Report a count to the counting server. Parameter 'id' is in case of repeated counting.
|
||||||
|
bool Count(const std::string &name, const std::string &id);
|
||||||
|
|
||||||
|
// Query whether the count reaches the threshold count for the name. If the count is the same as the threshold count,
|
||||||
|
// this method returns true.
|
||||||
|
bool CountReachThreshold(const std::string &name);
|
||||||
|
|
||||||
|
// Reset the count of the name to 0.
|
||||||
|
void ResetCounter(const std::string &name);
|
||||||
|
|
||||||
|
// Returns the server rank because in some cases the callers use this rank as the 'id' for method
|
||||||
|
// Count.
|
||||||
|
uint32_t local_rank() { return local_rank_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
DistributedCountService() = default;
|
||||||
|
~DistributedCountService() = default;
|
||||||
|
DistributedCountService(const DistributedCountService &) = delete;
|
||||||
|
DistributedCountService &operator=(const DistributedCountService &) = delete;
|
||||||
|
|
||||||
|
// Register callbacks of the counting server to handle messages sent by the other servers.
|
||||||
|
void RegisterCallback();
|
||||||
|
|
||||||
|
// Callback for the reporting count message from other servers. Only counting server will call this method.
|
||||||
|
void HandleCountRequest(const std::shared_ptr<core::MessageHandler> &message);
|
||||||
|
|
||||||
|
// Callback for the querying whether threshold count is reached message from other servers. Only counting
|
||||||
|
// server will call this method.
|
||||||
|
void HandleCountReachThresholdRequest(const std::shared_ptr<core::MessageHandler> &message);
|
||||||
|
|
||||||
|
// Callback for the first/last event message from the counting server. Only other servers will call this
|
||||||
|
// method.
|
||||||
|
void HandleCounterEvent(const std::shared_ptr<core::MessageHandler> &message);
|
||||||
|
|
||||||
|
// Call the callbacks when the first/last count event is triggered.
|
||||||
|
void TriggerCounterEvent(const std::string &name);
|
||||||
|
void TriggerFirstCountEvent(const std::string &name);
|
||||||
|
void TriggerLastCountEvent(const std::string &name);
|
||||||
|
|
||||||
|
// Members for the communication between counting server and other servers.
|
||||||
|
std::shared_ptr<core::ServerNode> server_node_;
|
||||||
|
std::shared_ptr<core::TcpCommunicator> communicator_;
|
||||||
|
uint32_t local_rank_;
|
||||||
|
uint32_t server_num_;
|
||||||
|
|
||||||
|
// Only one server will be set to do the real counting.
|
||||||
|
uint32_t counting_server_rank_;
|
||||||
|
|
||||||
|
// Key: name, e.g, startFLJob, updateModel, push.
|
||||||
|
// Value: a set of id without repeatation because each work may report multiple times.
|
||||||
|
std::unordered_map<std::string, std::set<std::string>> global_current_count_;
|
||||||
|
|
||||||
|
// Key: name, e.g, StartFLJobCount.
|
||||||
|
// Value: global threshold count in the server cluster dimension for this name.
|
||||||
|
std::unordered_map<std::string, size_t> global_threshold_count_;
|
||||||
|
|
||||||
|
// First/last count event callbacks of the name.
|
||||||
|
std::unordered_map<std::string, CounterHandlers> counter_handlers_;
|
||||||
|
|
||||||
|
// Because the count is increased/queried conccurently, we must ensure the operations are threadsafe.
|
||||||
|
std::unordered_map<std::string, std::mutex> mutex_;
|
||||||
|
};
|
||||||
|
} // namespace server
|
||||||
|
} // namespace ps
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_CCSRC_PS_SERVER_DISTRIBUTED_COUNT_SERVICE_H_
|
|
@ -0,0 +1,201 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "ps/server/distributed_metadata_store.h"
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace ps {
|
||||||
|
namespace server {
|
||||||
|
// Binds the metadata store to a server node, builds the consistent hash ring over all servers and registers the
// message callbacks.
//
// server_node: node used to obtain the TCP communicator and the local rank id. Must not be null.
void DistributedMetadataStore::Initialize(const std::shared_ptr<core::ServerNode> &server_node) {
  server_node_ = server_node;
  MS_EXCEPTION_IF_NULL(server_node);

  // NOTE(review): the placeholder arguments presumably fetch an already-created communicator - confirm against
  // ServerNode::GetOrCreateTcpComm.
  auto tcp_comm = server_node_->GetOrCreateTcpComm("", 0, 0, 0, nullptr);
  communicator_ = std::dynamic_pointer_cast<core::TcpCommunicator>(tcp_comm);
  MS_EXCEPTION_IF_NULL(communicator_);

  server_num_ = PSContext::instance()->initial_server_num();
  local_rank_ = server_node_->rank_id();

  // The hash ring is built from server_num_, so it has to be set first.
  InitHashRing();
  RegisterCallback();
}
|
||||||
|
|
||||||
|
void DistributedMetadataStore::RegisterMetadata(const std::string &name, const PBMetadata &meta) {
|
||||||
|
if (router_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "The consistent hash ring is not initialized yet.";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t stored_rank = router_->Find(name);
|
||||||
|
if (local_rank_ == stored_rank) {
|
||||||
|
if (metadata_.count(name) != 0) {
|
||||||
|
MS_LOG(ERROR) << "The metadata for " << name << " is already registered.";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
MS_LOG(INFO) << "Rank " << local_rank_ << " register storage for metadata " << name;
|
||||||
|
metadata_[name] = meta;
|
||||||
|
mutex_[name];
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
void DistributedMetadataStore::ResetMetadata(const std::string &name) {
|
||||||
|
if (router_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "The consistent hash ring is not initialized yet.";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t stored_rank = router_->Find(name);
|
||||||
|
if (local_rank_ == stored_rank) {
|
||||||
|
if (metadata_.count(name) == 0) {
|
||||||
|
MS_LOG(ERROR) << "The metadata for " << name << " is not registered.";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
MS_LOG(INFO) << "Rank " << local_rank_ << " reset metadata for " << name;
|
||||||
|
std::unique_lock<std::mutex> lock(mutex_[name]);
|
||||||
|
PBMetadata empty_meta;
|
||||||
|
metadata_[name] = empty_meta;
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
void DistributedMetadataStore::UpdateMetadata(const std::string &name, const PBMetadata &meta) {
|
||||||
|
if (router_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "The consistent hash ring is not initialized yet.";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t stored_rank = router_->Find(name);
|
||||||
|
MS_LOG(INFO) << "Rank " << local_rank_ << " update value for " << name << " which is stored in rank " << stored_rank;
|
||||||
|
if (local_rank_ == stored_rank) {
|
||||||
|
if (!DoUpdateMetadata(name, meta)) {
|
||||||
|
MS_LOG(ERROR) << "Updating meta data failed.";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
PBMetadataWithName metadata_with_name;
|
||||||
|
metadata_with_name.set_name(name);
|
||||||
|
*metadata_with_name.mutable_metadata() = meta;
|
||||||
|
if (!communicator_->SendPbRequest(metadata_with_name, stored_rank, core::TcpUserCommand::kUpdateMetadata)) {
|
||||||
|
MS_LOG(ERROR) << "Sending updating metadata message to server " << stored_rank << " failed.";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the metadata stored for `name`.
//
// If this server owns the name according to the hash ring, the value is read locally; otherwise a GetMetadataRequest
// is sent to the owning server. A default-constructed PBMetadata is returned when the hash ring is not initialized,
// the name is not stored locally, or the request/response exchange fails.
PBMetadata DistributedMetadataStore::GetMetadata(const std::string &name) {
  if (router_ == nullptr) {
    MS_LOG(ERROR) << "The consistent hash ring is not initialized yet.";
    return {};
  }

  uint32_t stored_rank = router_->Find(name);
  MS_LOG(INFO) << "Rank " << local_rank_ << " get metadata for " << name << " which is stored in rank " << stored_rank;
  if (local_rank_ == stored_rank) {
    // Do not use operator[] on an unknown name: it would create the entry, and a later RegisterMetadata for the same
    // name would then be rejected as "already registered".
    if (metadata_.count(name) == 0) {
      MS_LOG(ERROR) << "The metadata for " << name << " is not registered.";
      return {};
    }
    std::unique_lock<std::mutex> lock(mutex_[name]);
    return metadata_[name];
  }

  GetMetadataRequest get_metadata_req;
  get_metadata_req.set_name(name);
  PBMetadata get_metadata_rsp;

  std::shared_ptr<std::vector<unsigned char>> get_meta_rsp_msg = nullptr;
  if (!communicator_->SendPbRequest(get_metadata_req, stored_rank, core::TcpUserCommand::kGetMetadata,
                                    &get_meta_rsp_msg)) {
    MS_LOG(ERROR) << "Sending getting metadata message to server " << stored_rank << " failed.";
    return get_metadata_rsp;
  }
  // Guard the dereference below: the output buffer is only useful if the communicator actually filled it in.
  if (get_meta_rsp_msg == nullptr) {
    MS_LOG(ERROR) << "Response of getting metadata message from server " << stored_rank << " is empty.";
    return get_metadata_rsp;
  }
  if (!get_metadata_rsp.ParseFromArray(get_meta_rsp_msg->data(), static_cast<int>(get_meta_rsp_msg->size()))) {
    MS_LOG(ERROR) << "Parsing metadata response for " << name << " failed.";
    return {};
  }
  return get_metadata_rsp;
}
|
||||||
|
|
||||||
|
void DistributedMetadataStore::InitHashRing() {
|
||||||
|
router_ = std::make_shared<ConsistentHashRing>(32);
|
||||||
|
MS_EXCEPTION_IF_NULL(router_);
|
||||||
|
for (uint32_t i = 0; i < server_num_; i++) {
|
||||||
|
bool ret = router_->Insert(i);
|
||||||
|
if (!ret) {
|
||||||
|
MS_LOG(EXCEPTION) << "Add node " << i << " to router of meta storage failed.";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
void DistributedMetadataStore::RegisterCallback() {
|
||||||
|
communicator_->RegisterMsgCallBack(
|
||||||
|
"updateMetadata", std::bind(&DistributedMetadataStore::HandleUpdateMetadataRequest, this, std::placeholders::_1));
|
||||||
|
communicator_->RegisterMsgCallBack(
|
||||||
|
"getMetadata", std::bind(&DistributedMetadataStore::HandleGetMetadataRequest, this, std::placeholders::_1));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handles an update-metadata message from another server: applies the update locally and replies with a plain text
// status string.
void DistributedMetadataStore::HandleUpdateMetadataRequest(const std::shared_ptr<core::MessageHandler> &message) {
  if (!message) {
    MS_LOG(ERROR) << "Message is nullptr.";
    return;
  }

  PBMetadataWithName named_meta;
  named_meta.ParseFromArray(message->data(), message->len());
  const std::string &name = named_meta.name();
  MS_LOG(INFO) << "Update metadata for " << name;

  const bool updated = DoUpdateMetadata(name, named_meta.metadata());
  const std::string status = updated ? "Success" : "Updating meta data failed.";
  communicator_->SendResponse(status.data(), status.size(), message);
}
|
||||||
|
|
||||||
|
// Handles a get-metadata message from another server: replies with the serialized PBMetadata stored for the requested
// name, or with a serialized default PBMetadata if nothing is stored for it.
void DistributedMetadataStore::HandleGetMetadataRequest(const std::shared_ptr<core::MessageHandler> &message) {
  if (message == nullptr) {
    MS_LOG(ERROR) << "Message is nullptr.";
    return;
  }

  GetMetadataRequest get_metadata_req;
  get_metadata_req.ParseFromArray(message->data(), message->len());
  const std::string &name = get_metadata_req.name();
  MS_LOG(INFO) << "Getting metadata for " << name;

  PBMetadata stored_meta;
  // Do not use operator[] on a name coming from the network: for an unknown name it would insert new entries into the
  // shared maps, and a later RegisterMetadata for that name would then be rejected.
  if (metadata_.count(name) == 0) {
    MS_LOG(ERROR) << "The metadata for " << name << " is not registered. Respond with empty metadata.";
  } else {
    // Copy under the lock and release it before the network send.
    std::unique_lock<std::mutex> lock(mutex_[name]);
    stored_meta = metadata_[name];
  }

  std::string getting_meta_rsp_msg = stored_meta.SerializeAsString();
  communicator_->SendResponse(getting_meta_rsp_msg.data(), getting_meta_rsp_msg.size(), message);
}
|
||||||
|
|
||||||
|
bool DistributedMetadataStore::DoUpdateMetadata(const std::string &name, const PBMetadata &meta) {
|
||||||
|
std::unique_lock<std::mutex> lock(mutex_[name]);
|
||||||
|
metadata_[name] = meta;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
} // namespace server
|
||||||
|
} // namespace ps
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,101 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_PS_SERVER_DISTRIBUTED_META_STORE_H_
|
||||||
|
#define MINDSPORE_CCSRC_PS_SERVER_DISTRIBUTED_META_STORE_H_
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <memory>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include "proto/ps.pb.h"
|
||||||
|
#include "ps/server/common.h"
|
||||||
|
#include "ps/core/server_node.h"
|
||||||
|
#include "ps/core/communicator/tcp_communicator.h"
|
||||||
|
#include "ps/server/consistent_hash_ring.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace ps {
|
||||||
|
namespace server {
|
||||||
|
// This class is used for distributed metadata storage using consistent hash. All metadata is distributedly
|
||||||
|
// stored in all servers. Caller doesn't need to know which server stores the metadata. It only needs to know what kind
|
||||||
|
// of operations should be done to the metadata.
|
||||||
|
|
||||||
|
// The metadata stored in the server is in protobuffer format because it's easy for serializing and communicating. The
|
||||||
|
// type of the protobuffer struct is decided by the caller using protobuffer's API.
|
||||||
|
class DistributedMetadataStore {
|
||||||
|
public:
|
||||||
|
static DistributedMetadataStore &GetInstance() {
|
||||||
|
static DistributedMetadataStore instance;
|
||||||
|
return instance;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize metadata storage with the server node because communication is needed.
|
||||||
|
void Initialize(const std::shared_ptr<core::ServerNode> &server_node);
|
||||||
|
|
||||||
|
// Register metadata for the name with the initial value. This method should be only called once for each name.
|
||||||
|
void RegisterMetadata(const std::string &name, const PBMetadata &meta);
|
||||||
|
|
||||||
|
// Reset the metadata value for the name.
|
||||||
|
void ResetMetadata(const std::string &name);
|
||||||
|
|
||||||
|
// Update the metadata for the name.
|
||||||
|
void UpdateMetadata(const std::string &name, const PBMetadata &meta);
|
||||||
|
|
||||||
|
// Get the metadata for the name.
|
||||||
|
PBMetadata GetMetadata(const std::string &name);
|
||||||
|
|
||||||
|
private:
|
||||||
|
DistributedMetadataStore() = default;
|
||||||
|
~DistributedMetadataStore() = default;
|
||||||
|
DistributedMetadataStore(const DistributedMetadataStore &) = delete;
|
||||||
|
DistributedMetadataStore &operator=(const DistributedMetadataStore &) = delete;
|
||||||
|
|
||||||
|
// Initialize the consistent hash ring for distributed storage.
|
||||||
|
void InitHashRing();
|
||||||
|
|
||||||
|
// Register callbacks for the server to handle update/get metadata messages from other servers.
|
||||||
|
void RegisterCallback();
|
||||||
|
|
||||||
|
// Callback for updating metadata request sent to the server.
|
||||||
|
void HandleUpdateMetadataRequest(const std::shared_ptr<core::MessageHandler> &message);
|
||||||
|
|
||||||
|
// Callback for getting metadata request sent to the server.
|
||||||
|
void HandleGetMetadataRequest(const std::shared_ptr<core::MessageHandler> &message);
|
||||||
|
|
||||||
|
// Do updating metadata in the server where the metadata for the name is stored.
|
||||||
|
bool DoUpdateMetadata(const std::string &name, const PBMetadata &meta);
|
||||||
|
|
||||||
|
// Members for the communication between servers.
|
||||||
|
std::shared_ptr<core::ServerNode> server_node_;
|
||||||
|
std::shared_ptr<core::TcpCommunicator> communicator_;
|
||||||
|
uint32_t local_rank_;
|
||||||
|
uint32_t server_num_;
|
||||||
|
|
||||||
|
// Consistent hash ring. This is used for DistributedMetadataStore to find which server node the meta data is stored.
|
||||||
|
std::shared_ptr<ConsistentHashRing> router_;
|
||||||
|
|
||||||
|
// We store metadata which is serialized by ProtoBuffer so that data storage and data transmission API is easy to use.
|
||||||
|
// Key: data name.
|
||||||
|
// Value: ProtoBuffer Struct.
|
||||||
|
std::unordered_map<std::string, PBMetadata> metadata_;
|
||||||
|
|
||||||
|
// Because the metadata is read/written conccurently, we must ensure the operations are threadsafe.
|
||||||
|
std::unordered_map<std::string, std::mutex> mutex_;
|
||||||
|
};
|
||||||
|
} // namespace server
|
||||||
|
} // namespace ps
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_CCSRC_PS_SERVER_DISTRIBUTED_META_STORE_H_
|
|
@ -169,7 +169,7 @@ bool Executor::HandleOverwriteWeightsByKey(const std::map<std::string, Address>
|
||||||
}
|
}
|
||||||
|
|
||||||
AddressPtr Executor::HandlePull(const std::string ¶m_name) {
|
AddressPtr Executor::HandlePull(const std::string ¶m_name) {
|
||||||
MS_LOG(INFO) << "Handle blocking pull msg for parameter " << param_name;
|
MS_LOG(INFO) << "Handle blocking pull message for parameter " << param_name;
|
||||||
if (param_aggrs_.count(param_name) == 0) {
|
if (param_aggrs_.count(param_name) == 0) {
|
||||||
MS_LOG(WARNING) << "Parameter " << param_name << " is not registered in server.";
|
MS_LOG(WARNING) << "Parameter " << param_name << " is not registered in server.";
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -193,11 +193,6 @@ AddressPtr Executor::HandlePull(const std::string ¶m_name) {
|
||||||
return addr;
|
return addr;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::map<std::string, AddressPtr> Executor::HandleAsyncGetModel() {
|
|
||||||
std::unique_lock<std::mutex> lock(model_mutex_);
|
|
||||||
return GetModel();
|
|
||||||
}
|
|
||||||
|
|
||||||
std::map<std::string, AddressPtr> Executor::HandleGetWeightsByKey(const std::vector<std::string> ¶m_names) {
|
std::map<std::string, AddressPtr> Executor::HandleGetWeightsByKey(const std::vector<std::string> ¶m_names) {
|
||||||
std::map<std::string, AddressPtr> weights;
|
std::map<std::string, AddressPtr> weights;
|
||||||
for (const auto ¶m_name : param_names) {
|
for (const auto ¶m_name : param_names) {
|
||||||
|
|
|
@ -63,10 +63,6 @@ class Executor {
|
||||||
// asynchronously.
|
// asynchronously.
|
||||||
bool HandleModelUpdateAsync(const std::map<std::string, UploadData> &feature_map);
|
bool HandleModelUpdateAsync(const std::map<std::string, UploadData> &feature_map);
|
||||||
|
|
||||||
// Called in asynchronous federated learning training mode. Returns whole model in key-value where key refers to the
|
|
||||||
// parameter name.
|
|
||||||
std::map<std::string, AddressPtr> HandleAsyncGetModel();
|
|
||||||
|
|
||||||
// Forcibly overwrite specific weights in overwriteWeights message.
|
// Forcibly overwrite specific weights in overwriteWeights message.
|
||||||
bool HandleOverwriteWeightsByKey(const std::map<std::string, Address> &feature_map);
|
bool HandleOverwriteWeightsByKey(const std::map<std::string, Address> &feature_map);
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,76 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "ps/server/iteration.h"
|
||||||
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
#include <numeric>
|
||||||
|
#include "ps/server/model_store.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace ps {
|
||||||
|
namespace server {
|
||||||
|
// The iteration counter starts at 1; the starting value is published to the local meta store right away.
Iteration::Iteration() : iteration_num_(1) {
  LocalMetaStore::GetInstance().set_curr_iter_num(iteration_num_);
}
|
||||||
|
|
||||||
|
// Appends one round to this iteration. A null round is rejected by MS_EXCEPTION_IF_NULL before it is stored.
void Iteration::AddRound(const std::shared_ptr<Round> &round) {
  MS_EXCEPTION_IF_NULL(round);
  rounds_.emplace_back(round);
}
|
||||||
|
|
||||||
|
// Initializes every registered round with every communicator, then publishes the summed time window of all rounds
// to the local meta store under kCtxTotalTimeoutDuration.
//
// communicators:        communicators handed to each round's Initialize; must not be empty.
// timeout_cb:           forwarded to Round::Initialize.
// finish_iteration_cb:  forwarded to Round::Initialize.
void Iteration::InitRounds(const std::vector<std::shared_ptr<core::CommunicatorBase>> &communicators,
                           const TimeOutCb &timeout_cb, const FinishIterCb &finish_iteration_cb) {
  if (communicators.empty()) {
    MS_LOG(EXCEPTION) << "Communicators for rounds is empty.";
    return;
  }

  for (const auto &communicator : communicators) {
    for (auto &round : rounds_) {
      if (round == nullptr) {
        continue;
      }
      round->Initialize(communicator, timeout_cb, finish_iteration_cb);
    }
  }

  // The time window for one iteration, which will be used in some round kernels.
  // The accumulator is seeded with a size_t so the running total is not narrowed to int on every step, and null
  // rounds are skipped here just as they are skipped during initialization above.
  size_t iteration_time_window =
    std::accumulate(rounds_.begin(), rounds_.end(), static_cast<size_t>(0),
                    [](size_t total, const std::shared_ptr<Round> &round) -> size_t {
                      if (round == nullptr) {
                        return total;
                      }
                      return total + round->time_window();
                    });
  LocalMetaStore::GetInstance().put_value(kCtxTotalTimeoutDuration, iteration_time_window);
}
|
||||||
|
|
||||||
|
void Iteration::ProceedToNextIter() {
|
||||||
|
iteration_num_ = LocalMetaStore::GetInstance().curr_iter_num();
|
||||||
|
// Store the model for each iteration.
|
||||||
|
const auto &model = Executor::GetInstance().GetModel();
|
||||||
|
ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model);
|
||||||
|
|
||||||
|
for (auto &round : rounds_) {
|
||||||
|
round->Reset();
|
||||||
|
}
|
||||||
|
|
||||||
|
iteration_num_++;
|
||||||
|
LocalMetaStore::GetInstance().set_curr_iter_num(iteration_num_);
|
||||||
|
MS_LOG(INFO) << "Proceed to next iteration:" << iteration_num_ << "\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read-only access to the rounds registered through AddRound, in insertion order.
const std::vector<std::shared_ptr<Round>> &Iteration::rounds() {
  return rounds_;
}
|
||||||
|
} // namespace server
|
||||||
|
} // namespace ps
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,58 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_PS_SERVER_ITERATION_H_
|
||||||
|
#define MINDSPORE_CCSRC_PS_SERVER_ITERATION_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
#include "ps/core/communicator/communicator_base.h"
|
||||||
|
#include "ps/server/common.h"
|
||||||
|
#include "ps/server/round.h"
|
||||||
|
#include "ps/server/local_meta_store.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace ps {
|
||||||
|
namespace server {
|
||||||
|
// In server's logic, Iteration is the minimum execution unit. For each execution, it consists of multiple kinds of
|
||||||
|
// Rounds, only after all the rounds are finished, this iteration is considered as completed.
|
||||||
|
class Iteration {
|
||||||
|
public:
|
||||||
|
Iteration();
|
||||||
|
~Iteration() = default;
|
||||||
|
|
||||||
|
// Add a round for the iteration. This method will be called multiple times for each round.
|
||||||
|
void AddRound(const std::shared_ptr<Round> &round);
|
||||||
|
|
||||||
|
// Initialize all the rounds in the iteration.
|
||||||
|
void InitRounds(const std::vector<std::shared_ptr<core::CommunicatorBase>> &communicators,
|
||||||
|
const TimeOutCb &timeout_cb, const FinishIterCb &finish_iteration_cb);
|
||||||
|
|
||||||
|
// The server proceeds to the next iteration only after the last iteration finishes.
|
||||||
|
void ProceedToNextIter();
|
||||||
|
|
||||||
|
const std::vector<std::shared_ptr<Round>> &rounds();
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::vector<std::shared_ptr<Round>> rounds_;
|
||||||
|
|
||||||
|
// Server's current iteration number.
|
||||||
|
size_t iteration_num_;
|
||||||
|
};
|
||||||
|
} // namespace server
|
||||||
|
} // namespace ps
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_CCSRC_PS_SERVER_ITERATION_H_
|
|
@ -0,0 +1,127 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "ps/server/kernel/round/round_kernel.h"
|
||||||
|
#include <mutex>
|
||||||
|
#include <queue>
|
||||||
|
#include <chrono>
|
||||||
|
#include <thread>
|
||||||
|
#include <utility>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace ps {
|
||||||
|
namespace server {
|
||||||
|
namespace kernel {
|
||||||
|
// Sets up the kernel's bookkeeping fields and starts the background thread that frees response buffers queued
// through Release().
// NOTE(review): the lambda captures `this` by reference and the thread is detached, so it can keep running after
// the kernel is destroyed — confirm that round kernels live for the whole process lifetime.
RoundKernel::RoundKernel() : name_(""), current_count_(0), required_count_(0), error_reason_("") {
  release_thread_ = std::thread([&]() {
    for (;;) {
      AddressPtr pending;
      bool has_pending = false;
      {
        std::unique_lock<std::mutex> queue_guard(release_mtx_);
        if (!heap_data_to_release_.empty()) {
          pending = heap_data_to_release_.front();
          heap_data_to_release_.pop();
          has_pending = true;
        }
      }

      // Detect whether there's any data needs to be released every 100 milliseconds.
      if (!has_pending) {
        std::this_thread::sleep_for(std::chrono::milliseconds(100));
        continue;
      }

      std::unique_lock<std::mutex> store_guard(heap_data_mtx_);
      auto entry = heap_data_.find(pending);
      if (entry == heap_data_.end()) {
        MS_LOG(ERROR) << "The data is not stored.";
        continue;
      }
      // Erasing the entry destroys the owning unique_ptr, which frees the buffer.
      heap_data_.erase(entry);
    }
  });
  release_thread_.detach();
}
|
||||||
|
|
||||||
|
// Default handler for the first-count event: does nothing. Round kernels override it when they need the event.
void RoundKernel::OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &message) {}
|
||||||
|
|
||||||
|
// Default handler for the last-count event: does nothing. Round kernels override it when they need the event.
void RoundKernel::OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message) {}
|
||||||
|
|
||||||
|
void RoundKernel::StopTimer() {
|
||||||
|
if (stop_timer_cb_) {
|
||||||
|
stop_timer_cb_();
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
void RoundKernel::FinishIteration() {
|
||||||
|
if (finish_iteration_cb_) {
|
||||||
|
finish_iteration_cb_();
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Queues a response buffer for release. The actual free happens on the release thread started by the
// constructor; this call only enqueues under release_mtx_. A null addr_ptr is logged and ignored.
void RoundKernel::Release(AddressPtr addr_ptr) {
  if (addr_ptr != nullptr) {
    std::lock_guard<std::mutex> queue_guard(release_mtx_);
    heap_data_to_release_.push(addr_ptr);
    return;
  }
  MS_LOG(ERROR) << "Data to be released is empty.";
}
|
||||||
|
|
||||||
|
// Stores the round kernel's name.
void RoundKernel::set_name(const std::string &name) {
  name_ = name;
}
|
||||||
|
|
||||||
|
// Stores the callback that StopTimer() invokes.
void RoundKernel::set_stop_timer_cb(StopTimerCb timer_stopper) {
  stop_timer_cb_ = timer_stopper;
}
|
||||||
|
|
||||||
|
// Stores the callback that FinishIteration() invokes.
void RoundKernel::set_finish_iteration_cb(FinishIterCb finish_iteration_cb) { finish_iteration_cb_ = finish_iteration_cb; }
|
||||||
|
|
||||||
|
void RoundKernel::GenerateOutput(const std::vector<AddressPtr> &outputs, void *data, size_t len) {
|
||||||
|
if (data == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "The data is nullptr.";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (outputs.empty()) {
|
||||||
|
MS_LOG(ERROR) << "Generating output failed. Outputs size is empty.";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<unsigned char[]> output_data = std::make_unique<unsigned char[]>(len);
|
||||||
|
if (output_data == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Output data is nullptr.";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t dst_size = len;
|
||||||
|
int ret = memcpy_s(output_data.get(), dst_size, data, len);
|
||||||
|
if (ret != 0) {
|
||||||
|
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
outputs[0]->addr = output_data.get();
|
||||||
|
outputs[0]->size = len;
|
||||||
|
|
||||||
|
std::unique_lock<std::mutex> lock(heap_data_mtx_);
|
||||||
|
heap_data_.insert(std::make_pair(outputs[0], std::move(output_data)));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
} // namespace kernel
|
||||||
|
} // namespace server
|
||||||
|
} // namespace ps
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,130 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_ROUND_ROUND_KERNEL_H_
|
||||||
|
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_ROUND_ROUND_KERNEL_H_
|
||||||
|
|
||||||
|
#include <map>
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include <mutex>
|
||||||
|
#include <queue>
|
||||||
|
#include <utility>
|
||||||
|
#include <chrono>
|
||||||
|
#include <thread>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include "backend/kernel_compiler/common_utils.h"
|
||||||
|
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
|
||||||
|
#include "ps/server/common.h"
|
||||||
|
#include "ps/server/local_meta_store.h"
|
||||||
|
#include "ps/server/distributed_count_service.h"
|
||||||
|
#include "ps/server/distributed_metadata_store.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace ps {
|
||||||
|
namespace server {
|
||||||
|
namespace kernel {
|
||||||
|
// RoundKernel contains the main logic of server handling messages from workers. One iteration has multiple round
|
||||||
|
// kernels to represent the process. They receive and parse messages from the server communication module. After
|
||||||
|
// handling these messages, round kernels allocate response data and send it back.
|
||||||
|
|
||||||
|
// For example, the main process of federated learning is:
|
||||||
|
// startFLJob round->updateModel round->getModel round.
|
||||||
|
class RoundKernel : virtual public CPUKernel {
|
||||||
|
public:
|
||||||
|
RoundKernel();
|
||||||
|
virtual ~RoundKernel() = default;
|
||||||
|
|
||||||
|
// RoundKernel doesn't use InitKernel method of base class CPUKernel to initialize. So implementation of this
|
||||||
|
// inherited method is empty.
|
||||||
|
void InitKernel(const CNodePtr &kernel_node) override {}
|
||||||
|
|
||||||
|
// Initialize RoundKernel with threshold_count which means that for every iteration, this round needs threshold_count
|
||||||
|
// messages.
|
||||||
|
virtual void InitKernel(size_t threshold_count) = 0;
|
||||||
|
|
||||||
|
// Launch the round kernel logic to handle the message passed by the communication module.
|
||||||
|
virtual bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||||
|
const std::vector<AddressPtr> &outputs) = 0;
|
||||||
|
|
||||||
|
// The callbacks when first message and last message for this round kernel is received.
|
||||||
|
// These methods is called by class DistributedCountService and triggered by leader server(Rank 0).
|
||||||
|
// virtual void OnFirstCountEvent(std::shared_ptr<core::MessageHandler> message);
|
||||||
|
// virtual void OnLastCnt(std::shared_ptr<core::MessageHandler> message);
|
||||||
|
|
||||||
|
// Some rounds could be stateful in a iteration. Reset method resets the status of this round.
|
||||||
|
virtual bool Reset() = 0;
|
||||||
|
|
||||||
|
// The counter event handlers for DistributedCountService.
|
||||||
|
virtual void OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &message);
|
||||||
|
virtual void OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message);
|
||||||
|
|
||||||
|
// Called when this round is finished. This round timer's Stop method will be called.
|
||||||
|
void StopTimer();
|
||||||
|
|
||||||
|
// Called after this iteration(including all rounds) is finished. All rounds' Reset method will
|
||||||
|
// be called.
|
||||||
|
void FinishIteration();
|
||||||
|
|
||||||
|
// Release the response data allocated inside the round kernel.
|
||||||
|
// Server framework must call this after the response data is sent back.
|
||||||
|
void Release(AddressPtr addr_ptr);
|
||||||
|
|
||||||
|
// Set round kernel name, which could be used in round kernel's methods.
|
||||||
|
void set_name(const std::string &name);
|
||||||
|
|
||||||
|
// Set callbacks to be called under certain triggered conditions.
|
||||||
|
void set_stop_timer_cb(StopTimerCb timer_stopper);
|
||||||
|
void set_finish_iteration_cb(FinishIterCb finish_iteration_cb);
|
||||||
|
|
||||||
|
protected:
|
||||||
|
// Generating response data of this round. The data is allocated on the heap to ensure it's not released before sent
|
||||||
|
// back to worker.
|
||||||
|
void GenerateOutput(const std::vector<AddressPtr> &outputs, void *data, size_t len);
|
||||||
|
|
||||||
|
// Round kernel's name.
|
||||||
|
std::string name_;
|
||||||
|
|
||||||
|
// The current received message count for this round in this iteration.
|
||||||
|
size_t current_count_;
|
||||||
|
|
||||||
|
// The required received message count for this round in one iteration.
|
||||||
|
size_t required_count_;
|
||||||
|
|
||||||
|
// The reason causes the error in this round kernel.
|
||||||
|
std::string error_reason_;
|
||||||
|
|
||||||
|
StopTimerCb stop_timer_cb_;
|
||||||
|
FinishIterCb finish_iteration_cb_;
|
||||||
|
|
||||||
|
// Members below are used for allocating and releasing response data on the heap.
|
||||||
|
|
||||||
|
// To ensure the performance, we use another thread to release data on the heap. So the operation on the data should
|
||||||
|
// be threadsafe.
|
||||||
|
std::thread release_thread_;
|
||||||
|
|
||||||
|
// Data needs to be released and its mutex;
|
||||||
|
std::mutex release_mtx_;
|
||||||
|
std::queue<AddressPtr> heap_data_to_release_;
|
||||||
|
std::mutex heap_data_mtx_;
|
||||||
|
std::unordered_map<AddressPtr, std::unique_ptr<unsigned char[]>> heap_data_;
|
||||||
|
};
|
||||||
|
} // namespace kernel
|
||||||
|
} // namespace server
|
||||||
|
} // namespace ps
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_ROUND_ROUND_KERNEL_H_
|
|
@ -0,0 +1,44 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "ps/server/kernel/round/round_kernel_factory.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace ps {
|
||||||
|
namespace server {
|
||||||
|
namespace kernel {
|
||||||
|
// Returns the single factory instance, constructed on first use.
RoundKernelFactory &RoundKernelFactory::GetInstance() {
  static RoundKernelFactory factory;
  return factory;
}
|
||||||
|
|
||||||
|
// Registers (or replaces) the creator used to build the round kernel called `name`.
//
// name:    key under which the creator is stored.
// creator: factory callable; it is moved into the registry.
void RoundKernelFactory::Register(const std::string &name, RoundKernelCreator &&creator) {
  // A named rvalue reference is an lvalue, so it has to be moved explicitly to avoid copying the std::function.
  name_to_creator_map_[name] = std::move(creator);
}
|
||||||
|
|
||||||
|
std::shared_ptr<RoundKernel> RoundKernelFactory::Create(const std::string &name) {
|
||||||
|
if (name_to_creator_map_.count(name) == 0) {
|
||||||
|
MS_LOG(ERROR) << "Round kernel " << name << " is not registered.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
auto kernel = name_to_creator_map_[name]();
|
||||||
|
kernel->set_name(name);
|
||||||
|
return kernel;
|
||||||
|
}
|
||||||
|
} // namespace kernel
|
||||||
|
} // namespace server
|
||||||
|
} // namespace ps
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,62 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_ROUND_ROUND_KERNEL_FACTORY_H_
|
||||||
|
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_ROUND_ROUND_KERNEL_FACTORY_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include "ps/server/common.h"
|
||||||
|
#include "ps/server/kernel/round/round_kernel.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace ps {
|
||||||
|
namespace server {
|
||||||
|
namespace kernel {
|
||||||
|
using RoundKernelCreator = std::function<std::shared_ptr<RoundKernel>()>;
|
||||||
|
// Kernel factory of round kernels.
|
||||||
|
class RoundKernelFactory {
|
||||||
|
public:
|
||||||
|
static RoundKernelFactory &GetInstance();
|
||||||
|
void Register(const std::string &name, RoundKernelCreator &&creator);
|
||||||
|
std::shared_ptr<RoundKernel> Create(const std::string &name);
|
||||||
|
|
||||||
|
private:
|
||||||
|
RoundKernelFactory() = default;
|
||||||
|
~RoundKernelFactory() = default;
|
||||||
|
RoundKernelFactory(const RoundKernelFactory &) = delete;
|
||||||
|
RoundKernelFactory &operator=(const RoundKernelFactory &) = delete;
|
||||||
|
|
||||||
|
std::unordered_map<std::string, RoundKernelCreator> name_to_creator_map_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class RoundKernelRegister {
|
||||||
|
public:
|
||||||
|
RoundKernelRegister(const std::string &name, RoundKernelCreator &&creator) {
|
||||||
|
RoundKernelFactory::GetInstance().Register(name, std::move(creator));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Registers CLASS with RoundKernelFactory under the string #NAME by defining a static RoundKernelRegister.
// CLASS must derive from RoundKernel; the static_assert names the offending class when it does not.
#define REG_ROUND_KERNEL(NAME, CLASS)                                                                     \
  static_assert(std::is_base_of<RoundKernel, CLASS>::value, #CLASS " must be derived from RoundKernel"); \
  static const RoundKernelRegister g_##NAME##_round_kernel_reg(#NAME, []() { return std::make_shared<CLASS>(); });
|
||||||
|
} // namespace kernel
|
||||||
|
} // namespace server
|
||||||
|
} // namespace ps
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_ROUND_ROUND_KERNEL_FACTORY_H_
|
|
@ -0,0 +1,192 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "ps/server/kernel/round/start_fl_job_kernel.h"
|
||||||
|
#include <map>
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace ps {
|
||||||
|
namespace server {
|
||||||
|
namespace kernel {
|
||||||
|
void StartFLJobKernel::InitKernel(size_t) {
|
||||||
|
if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) {
|
||||||
|
iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
|
||||||
|
}
|
||||||
|
|
||||||
|
executor_ = &Executor::GetInstance();
|
||||||
|
MS_EXCEPTION_IF_NULL(executor_);
|
||||||
|
if (!executor_->initialized()) {
|
||||||
|
MS_LOG(EXCEPTION) << "Executor must be initialized in server pipeline.";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
PBMetadata devices_metas;
|
||||||
|
DistributedMetadataStore::GetInstance().RegisterMetadata(kCtxDeviceMetas, devices_metas);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool StartFLJobKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||||
|
const std::vector<AddressPtr> &outputs) {
|
||||||
|
MS_LOG(INFO) << "Launching StartFLJobKernel kernel.";
|
||||||
|
if (inputs.size() != 1 || outputs.size() != 1) {
|
||||||
|
MS_LOG(ERROR) << "inputs or outputs size is invalid.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
void *req_data = inputs[0]->addr;
|
||||||
|
const std::shared_ptr<FBBuilder> &fbb = std::make_shared<FBBuilder>();
|
||||||
|
if (fbb == nullptr || req_data == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "FBBuilder builder or req_data is nullptr.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ReachThresholdForStartFLJob(fbb)) {
|
||||||
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
const schema::RequestFLJob *start_fl_job_req = flatbuffers::GetRoot<schema::RequestFLJob>(req_data);
|
||||||
|
DeviceMeta device_meta = CreateDeviceMetadata(start_fl_job_req);
|
||||||
|
if (!ReadyForStartFLJob(fbb, device_meta)) {
|
||||||
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
// If calling ReportCount before ReadyForStartFLJob, the result will be inconsistent if the device is not selected.
|
||||||
|
if (!CountForStartFLJob(fbb, start_fl_job_req)) {
|
||||||
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
StartFLJob(fbb, device_meta);
|
||||||
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Resets this round: stops the round timer, resets the counter registered under this kernel's name and resets
// the metadata stored under kCtxDeviceMetas. Always returns true.
bool StartFLJobKernel::Reset() {
  MS_LOG(INFO) << "Starting fl job kernel reset!";

  StopTimer();

  auto &count_service = DistributedCountService::GetInstance();
  count_service.ResetCounter(name_);

  auto &metadata_store = DistributedMetadataStore::GetInstance();
  metadata_store.ResetMetadata(kCtxDeviceMetas);

  return true;
}
|
||||||
|
|
||||||
|
bool StartFLJobKernel::ReachThresholdForStartFLJob(const std::shared_ptr<FBBuilder> &fbb) {
|
||||||
|
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
|
||||||
|
std::string reason = "Current amount for startFLJob has reached the threshold. Please startFLJob later.";
|
||||||
|
BuildStartFLJobRsp(fbb, schema::ResponseCode_OutOfTime, reason, false,
|
||||||
|
std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_));
|
||||||
|
MS_LOG(ERROR) << reason;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Builds a DeviceMeta from the fl_name, fl_id and data_size fields of a startFLJob request.
//
// start_fl_job_req: parsed request. When it is nullptr an error is logged and a default-constructed DeviceMeta is
//                   returned.
//
// String fields that are absent from the request are treated as empty strings instead of being dereferenced.
DeviceMeta StartFLJobKernel::CreateDeviceMetadata(const schema::RequestFLJob *start_fl_job_req) {
  DeviceMeta device_meta;
  if (start_fl_job_req == nullptr) {
    MS_LOG(ERROR) << "The startFLJob request is nullptr.";
    return device_meta;
  }

  // FlatBuffers accessors for string fields return nullptr when the field was not written by the sender.
  const auto *fbs_fl_name = start_fl_job_req->fl_name();
  const auto *fbs_fl_id = start_fl_job_req->fl_id();
  std::string fl_name = (fbs_fl_name == nullptr) ? std::string() : fbs_fl_name->str();
  std::string fl_id = (fbs_fl_id == nullptr) ? std::string() : fbs_fl_id->str();
  int data_size = start_fl_job_req->data_size();
  MS_LOG(INFO) << "DeviceMeta fl_name:" << fl_name << ", fl_id:" << fl_id << ", data_size:" << data_size;

  device_meta.set_fl_name(fl_name);
  device_meta.set_fl_id(fl_id);
  device_meta.set_data_size(data_size);
  return device_meta;
}
|
||||||
|
|
||||||
|
bool StartFLJobKernel::ReadyForStartFLJob(const std::shared_ptr<FBBuilder> &fbb, const DeviceMeta &device_meta) {
|
||||||
|
bool ret = true;
|
||||||
|
std::string reason = "";
|
||||||
|
if (device_meta.data_size() < 1) {
|
||||||
|
reason = "FL job data size is not enough.";
|
||||||
|
ret = false;
|
||||||
|
}
|
||||||
|
if (!ret) {
|
||||||
|
BuildStartFLJobRsp(fbb, schema::ResponseCode_NotSelected, reason, false,
|
||||||
|
std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_));
|
||||||
|
MS_LOG(ERROR) << reason;
|
||||||
|
}
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool StartFLJobKernel::CountForStartFLJob(const std::shared_ptr<FBBuilder> &fbb,
|
||||||
|
const schema::RequestFLJob *start_fl_job_req) {
|
||||||
|
if (!DistributedCountService::GetInstance().Count(name_, start_fl_job_req->fl_id()->str())) {
|
||||||
|
std::string reason = "startFLJob counting failed.";
|
||||||
|
BuildStartFLJobRsp(fbb, schema::ResponseCode_OutOfTime, reason, false,
|
||||||
|
std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_));
|
||||||
|
MS_LOG(ERROR) << reason;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void StartFLJobKernel::StartFLJob(const std::shared_ptr<FBBuilder> &fbb, const DeviceMeta &device_meta) {
|
||||||
|
PBMetadata metadata;
|
||||||
|
*metadata.mutable_device_meta() = device_meta;
|
||||||
|
DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxDeviceMetas, metadata);
|
||||||
|
|
||||||
|
std::map<std::string, AddressPtr> feature_maps = executor_->GetModel();
|
||||||
|
BuildStartFLJobRsp(fbb, schema::ResponseCode_SUCCEED, "success", true,
|
||||||
|
std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_), feature_maps);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Serializes a ResponseFLJob message into fbb. Used for both accepted and rejected requests.
//   fbb           - flatbuffer builder that receives the finished response.
//   retcode       - response code reported to the device.
//   reason        - human readable explanation of retcode.
//   is_selected   - whether the device takes part in this iteration.
//   next_req_time - textual timestamp telling the device when it may send its next request.
//   feature_maps  - weight name -> weight buffer; each buffer is serialized as an array of float.
void StartFLJobKernel::BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode,
                                          const std::string &reason, const bool is_selected,
                                          const std::string &next_req_time,
                                          std::map<std::string, AddressPtr> feature_maps) {
  // Strings, nested tables and vectors are all created before the table builder that references them is started.
  auto fbs_reason = fbb->CreateString(reason);
  auto fbs_next_req_time = fbb->CreateString(next_req_time);
  auto fbs_fl_name = fbb->CreateString(PSContext::instance()->fl_name());

  schema::FLPlanBuilder fl_plan_builder(*fbb);
  fl_plan_builder.add_fl_name(fbs_fl_name);
  fl_plan_builder.add_iterations(PSContext::instance()->fl_iteration_num());
  fl_plan_builder.add_epochs(PSContext::instance()->client_epoch_num());
  fl_plan_builder.add_mini_batch(PSContext::instance()->client_batch_size());
  auto fbs_fl_plan = fl_plan_builder.Finish();

  std::vector<flatbuffers::Offset<schema::FeatureMap>> fbs_feature_maps;
  fbs_feature_maps.reserve(feature_maps.size());
  // Iterate by const reference: copying each entry would copy a std::string and bump a shared_ptr refcount.
  for (const auto &feature_map : feature_maps) {
    auto fbs_weight_fullname = fbb->CreateString(feature_map.first);
    // The weight buffer size is in bytes; it is converted to a float element count here.
    auto fbs_weight_data =
      fbb->CreateVector(reinterpret_cast<float *>(feature_map.second->addr), feature_map.second->size / sizeof(float));
    auto fbs_feature_map = schema::CreateFeatureMap(*fbb, fbs_weight_fullname, fbs_weight_data);
    fbs_feature_maps.push_back(fbs_feature_map);
  }
  auto fbs_feature_maps_vector = fbb->CreateVector(fbs_feature_maps);

  schema::ResponseFLJobBuilder rsp_fl_job_builder(*fbb);
  rsp_fl_job_builder.add_retcode(retcode);
  rsp_fl_job_builder.add_reason(fbs_reason);
  rsp_fl_job_builder.add_iteration(LocalMetaStore::GetInstance().curr_iter_num());
  rsp_fl_job_builder.add_is_selected(is_selected);
  rsp_fl_job_builder.add_next_req_time(fbs_next_req_time);
  rsp_fl_job_builder.add_fl_plan_config(fbs_fl_plan);
  rsp_fl_job_builder.add_feature_map(fbs_feature_maps_vector);
  auto rsp_fl_job = rsp_fl_job_builder.Finish();
  fbb->Finish(rsp_fl_job);
}
|
||||||
|
|
||||||
|
REG_ROUND_KERNEL(startFLJob, StartFLJobKernel)
|
||||||
|
} // namespace kernel
|
||||||
|
} // namespace server
|
||||||
|
} // namespace ps
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,74 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_START_FL_JOB_KERNEL_H_
|
||||||
|
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_START_FL_JOB_KERNEL_H_
|
||||||
|
|
||||||
|
#include <map>
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include "ps/server/common.h"
|
||||||
|
#include "ps/server/executor.h"
|
||||||
|
#include "ps/server/kernel/round/round_kernel.h"
|
||||||
|
#include "ps/server/kernel/round/round_kernel_factory.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace ps {
|
||||||
|
namespace server {
|
||||||
|
namespace kernel {
|
||||||
|
class StartFLJobKernel : public RoundKernel {
|
||||||
|
public:
|
||||||
|
StartFLJobKernel() = default;
|
||||||
|
~StartFLJobKernel() override = default;
|
||||||
|
|
||||||
|
void InitKernel(size_t threshold_count) override;
|
||||||
|
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||||
|
const std::vector<AddressPtr> &outputs) override;
|
||||||
|
bool Reset() override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
// Returns whether the startFLJob count of this iteration has reached the threshold.
|
||||||
|
bool ReachThresholdForStartFLJob(const std::shared_ptr<FBBuilder> &fbb);
|
||||||
|
|
||||||
|
// The metadata of device will be stored and queried in updateModel round.
|
||||||
|
DeviceMeta CreateDeviceMetadata(const schema::RequestFLJob *start_fl_job_req);
|
||||||
|
|
||||||
|
// Returns whether the request is valid for startFLJob.For now, the condition is simple. We will add more conditions
|
||||||
|
// to device in later versions.
|
||||||
|
bool ReadyForStartFLJob(const std::shared_ptr<FBBuilder> &fbb, const DeviceMeta &device_meta);
|
||||||
|
|
||||||
|
// Distributed count service counts for startFLJob.
|
||||||
|
bool CountForStartFLJob(const std::shared_ptr<FBBuilder> &fbb, const schema::RequestFLJob *start_fl_job_req);
|
||||||
|
|
||||||
|
void StartFLJob(const std::shared_ptr<FBBuilder> &fbb, const DeviceMeta &device_meta);
|
||||||
|
|
||||||
|
// Build response for startFLJob round no matter success or failure.
|
||||||
|
void BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode,
|
||||||
|
const std::string &reason, const bool is_selected, const std::string &next_req_time,
|
||||||
|
std::map<std::string, AddressPtr> feature_maps = {});
|
||||||
|
|
||||||
|
// The executor is for getting the initial model for startFLJob request.
|
||||||
|
Executor *executor_;
|
||||||
|
|
||||||
|
// The time window of one iteration.
|
||||||
|
size_t iteration_time_window_;
|
||||||
|
};
|
||||||
|
} // namespace kernel
|
||||||
|
} // namespace server
|
||||||
|
} // namespace ps
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_START_FL_JOB_KERNEL_H_
|
|
@ -14,30 +14,29 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include "ps/server/local_meta_storage.h"
|
#include "ps/server/local_meta_store.h"
|
||||||
#include <string>
|
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace ps {
|
||||||
namespace server {
|
namespace server {
|
||||||
void LocalMetaStorage::remove_value(const std::string &name) {
|
void LocalMetaStore::remove_value(const std::string &name) {
|
||||||
std::unique_lock<std::mutex> lock(mtx_);
|
std::unique_lock<std::mutex> lock(mtx_);
|
||||||
if (key_to_meta_.count(name) != 0) {
|
if (key_to_meta_.count(name) != 0) {
|
||||||
key_to_meta_.erase(key_to_meta_.find(name));
|
key_to_meta_.erase(key_to_meta_.find(name));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool LocalMetaStorage::has_value(const std::string &name) {
|
bool LocalMetaStore::has_value(const std::string &name) {
|
||||||
std::unique_lock<std::mutex> lock(mtx_);
|
std::unique_lock<std::mutex> lock(mtx_);
|
||||||
return key_to_meta_.count(name) != 0;
|
return key_to_meta_.count(name) != 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
void LocalMetaStorage::set_curr_iter_num(size_t num) {
|
void LocalMetaStore::set_curr_iter_num(size_t num) {
|
||||||
std::unique_lock<std::mutex> lock(mtx_);
|
std::unique_lock<std::mutex> lock(mtx_);
|
||||||
curr_iter_num_ = num;
|
curr_iter_num_ = num;
|
||||||
}
|
}
|
||||||
|
|
||||||
const size_t LocalMetaStorage::curr_iter_num() {
|
const size_t LocalMetaStore::curr_iter_num() {
|
||||||
std::unique_lock<std::mutex> lock(mtx_);
|
std::unique_lock<std::mutex> lock(mtx_);
|
||||||
return curr_iter_num_;
|
return curr_iter_num_;
|
||||||
}
|
}
|
|
@ -14,8 +14,8 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef MINDSPORE_CCSRC_PS_SERVER_LOCAL_META_STORAGE_H_
|
#ifndef MINDSPORE_CCSRC_PS_SERVER_LOCAL_META_STORE_H_
|
||||||
#define MINDSPORE_CCSRC_PS_SERVER_LOCAL_META_STORAGE_H_
|
#define MINDSPORE_CCSRC_PS_SERVER_LOCAL_META_STORE_H_
|
||||||
|
|
||||||
#include <any>
|
#include <any>
|
||||||
#include <mutex>
|
#include <mutex>
|
||||||
|
@ -26,13 +26,13 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace ps {
|
||||||
namespace server {
|
namespace server {
|
||||||
// LocalMetaStorage class is used for metadata storage of this server process.
|
// LocalMetaStore class is used for metadata storage of this server process.
|
||||||
// For example, the current iteration number, time windows for round kernels, etc.
|
// For example, the current iteration number, time windows for round kernels, etc.
|
||||||
// LocalMetaStorage is threadsafe.
|
// LocalMetaStore is threadsafe.
|
||||||
class LocalMetaStorage {
|
class LocalMetaStore {
|
||||||
public:
|
public:
|
||||||
static LocalMetaStorage &GetInstance() {
|
static LocalMetaStore &GetInstance() {
|
||||||
static LocalMetaStorage instance;
|
static LocalMetaStore instance;
|
||||||
return instance;
|
return instance;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -43,7 +43,7 @@ class LocalMetaStorage {
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
const T &value(const std::string &name) {
|
T value(const std::string &name) {
|
||||||
std::unique_lock<std::mutex> lock(mtx_);
|
std::unique_lock<std::mutex> lock(mtx_);
|
||||||
try {
|
try {
|
||||||
T value = std::any_cast<T>(key_to_meta_[name]);
|
T value = std::any_cast<T>(key_to_meta_[name]);
|
||||||
|
@ -71,10 +71,10 @@ class LocalMetaStorage {
|
||||||
const size_t curr_iter_num();
|
const size_t curr_iter_num();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
LocalMetaStorage() = default;
|
LocalMetaStore() = default;
|
||||||
~LocalMetaStorage() = default;
|
~LocalMetaStore() = default;
|
||||||
LocalMetaStorage(const LocalMetaStorage &) = delete;
|
LocalMetaStore(const LocalMetaStore &) = delete;
|
||||||
LocalMetaStorage &operator=(const LocalMetaStorage &) = delete;
|
LocalMetaStore &operator=(const LocalMetaStore &) = delete;
|
||||||
|
|
||||||
// key_to_meta_ stores metadata with key-value format.
|
// key_to_meta_ stores metadata with key-value format.
|
||||||
std::unordered_map<std::string, std::any> key_to_meta_;
|
std::unordered_map<std::string, std::any> key_to_meta_;
|
||||||
|
@ -85,4 +85,4 @@ class LocalMetaStorage {
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace ps
|
} // namespace ps
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_PS_SERVER_LOCAL_META_STORAGE_H_
|
#endif // MINDSPORE_CCSRC_PS_SERVER_LOCAL_META_STORE_H_
|
|
@ -0,0 +1,144 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "ps/server/model_store.h"
|
||||||
|
#include <map>
|
||||||
|
#include <string>
|
||||||
|
#include <memory>
|
||||||
|
#include "ps/server/executor.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace ps {
|
||||||
|
namespace server {
|
||||||
|
void ModelStore::Init(uint32_t max_count) {
|
||||||
|
if (!Executor::GetInstance().initialized()) {
|
||||||
|
MS_LOG(EXCEPTION) << "Server's executor must be initialized before model storage.";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
max_model_count_ = max_count;
|
||||||
|
iteration_to_model_[kInitIterationNum] = AssignNewModelMemory();
|
||||||
|
model_size_ = ComputeModelSize();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool ModelStore::StoreModelByIterNum(size_t iteration, const std::map<std::string, AddressPtr> &new_model) {
|
||||||
|
if (iteration_to_model_.count(iteration) != 0) {
|
||||||
|
MS_LOG(WARNING) << "Model for iteration " << iteration << " is already stored";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (new_model.empty()) {
|
||||||
|
MS_LOG(ERROR) << "Model feature map is empty.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<MemoryRegister> memory_register;
|
||||||
|
if (iteration_to_model_.size() < max_model_count_) {
|
||||||
|
// If iteration_to_model_.size() is not max_model_count_, need to assign new memory for the model.
|
||||||
|
memory_register = AssignNewModelMemory();
|
||||||
|
if (memory_register == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Memory for the new model is nullptr.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
iteration_to_model_[iteration] = memory_register;
|
||||||
|
} else {
|
||||||
|
// If iteration_to_model_ size is already max_model_count_, we need to replace earliest model with the newest model.
|
||||||
|
memory_register = iteration_to_model_.begin()->second;
|
||||||
|
if (memory_register == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Earliest model is nullptr.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
iteration_to_model_.erase(iteration_to_model_.begin());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy new model data to the the stored model.
|
||||||
|
auto &stored_model = memory_register->addresses();
|
||||||
|
for (const auto &weight : new_model) {
|
||||||
|
const std::string &weight_name = weight.first;
|
||||||
|
if (stored_model.count(weight_name) != 0) {
|
||||||
|
MS_LOG(ERROR) << "The stored model has no weight " << weight_name;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
void *dst_addr = stored_model[weight_name]->addr;
|
||||||
|
size_t dst_size = stored_model[weight_name]->size;
|
||||||
|
void *src_addr = weight.second->addr;
|
||||||
|
size_t src_size = weight.second->size;
|
||||||
|
int ret = memcpy_s(dst_addr, dst_size, src_addr, src_size);
|
||||||
|
if (ret != 0) {
|
||||||
|
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
iteration_to_model_[iteration] = memory_register;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::map<std::string, AddressPtr> ModelStore::GetModelByIterNum(size_t iteration) {
|
||||||
|
std::map<std::string, AddressPtr> model = {};
|
||||||
|
if (iteration_to_model_.count(iteration) == 0) {
|
||||||
|
MS_LOG(ERROR) << "Model for iteration " << iteration << " is not stored.";
|
||||||
|
return model;
|
||||||
|
}
|
||||||
|
model = iteration_to_model_[iteration]->addresses();
|
||||||
|
return model;
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::map<size_t, std::shared_ptr<MemoryRegister>> &ModelStore::iteration_to_model() const {
|
||||||
|
return iteration_to_model_;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the model size cached in model_size_.
size_t ModelStore::model_size() const {
  return model_size_;
}
|
||||||
|
|
||||||
|
std::shared_ptr<MemoryRegister> ModelStore::AssignNewModelMemory() {
|
||||||
|
std::map<std::string, AddressPtr> model = Executor::GetInstance().GetModel();
|
||||||
|
if (model.empty()) {
|
||||||
|
MS_LOG(EXCEPTION) << "Model feature map is empty.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assign new memory for the model.
|
||||||
|
std::shared_ptr<MemoryRegister> memory_register = std::make_shared<MemoryRegister>();
|
||||||
|
for (const auto &weight : model) {
|
||||||
|
const std::string weight_name = weight.first;
|
||||||
|
size_t weight_size = weight.second->size;
|
||||||
|
auto weight_data = std::make_unique<char[]>(weight_size);
|
||||||
|
if (weight_data == nullptr) {
|
||||||
|
MS_LOG(EXCEPTION) << "Assign memory for weight failed.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
memory_register->RegisterArray(weight_name, &weight_data, weight_size);
|
||||||
|
}
|
||||||
|
return memory_register;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t ModelStore::ComputeModelSize() {
|
||||||
|
if (iteration_to_model_.empty()) {
|
||||||
|
MS_LOG(EXCEPTION) << "Calculating model size failed: model for iteration 0 is not stored yet. ";
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto &model = iteration_to_model_[kInitIterationNum];
|
||||||
|
MS_EXCEPTION_IF_NULL(model);
|
||||||
|
size_t model_size = std::accumulate(model->addresses().begin(), model->addresses().end(), static_cast<size_t>(0),
|
||||||
|
[](size_t s, const auto &weight) { return s + weight.second->size; });
|
||||||
|
MS_LOG(INFO) << "Model size in byte is " << model_size;
|
||||||
|
return model_size;
|
||||||
|
}
|
||||||
|
} // namespace server
|
||||||
|
} // namespace ps
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,78 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_PS_SERVER_MODEL_STORE_H_
|
||||||
|
#define MINDSPORE_CCSRC_PS_SERVER_MODEL_STORE_H_
|
||||||
|
|
||||||
|
#include <map>
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include "ps/server/common.h"
|
||||||
|
#include "ps/server/memory_register.h"
|
||||||
|
#include "ps/server/executor.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace ps {
|
||||||
|
namespace server {
|
||||||
|
// The initial iteration number is 0 in server.
|
||||||
|
constexpr size_t kInitIterationNum = 0;
|
||||||
|
|
||||||
|
// Server framework use ModelStore to store and query models.
|
||||||
|
// ModelStore stores multiple models because worker could get models of the previous iterations.
|
||||||
|
class ModelStore {
|
||||||
|
public:
|
||||||
|
static ModelStore &GetInstance() {
|
||||||
|
static ModelStore instance;
|
||||||
|
return instance;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize ModelStore with max count of models need to be stored.
|
||||||
|
void Init(uint32_t max_count = 3);
|
||||||
|
|
||||||
|
// Store the model of the given iteration. The model is acquired from Executor. If the current model count is already
|
||||||
|
// max_model_count_, the earliest model will be replaced.
|
||||||
|
bool StoreModelByIterNum(size_t iteration, const std::map<std::string, AddressPtr> &model);
|
||||||
|
|
||||||
|
// Get model of the given iteration.
|
||||||
|
std::map<std::string, AddressPtr> GetModelByIterNum(size_t iteration);
|
||||||
|
|
||||||
|
// Returns all models stored in ModelStore.
|
||||||
|
const std::map<size_t, std::shared_ptr<MemoryRegister>> &iteration_to_model() const;
|
||||||
|
|
||||||
|
// Returns the model size, which could be calculated at the initializing phase.
|
||||||
|
size_t model_size() const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
ModelStore() : max_model_count_(0), model_size_(0), iteration_to_model_({}) {}
|
||||||
|
~ModelStore() = default;
|
||||||
|
ModelStore(const ModelStore &) = delete;
|
||||||
|
ModelStore &operator=(const ModelStore &) = delete;
|
||||||
|
|
||||||
|
// To store multiple models, new memory must assigned. The max memory size assigned for models is max_model_count_ *
|
||||||
|
// model_size_.
|
||||||
|
std::shared_ptr<MemoryRegister> AssignNewModelMemory();
|
||||||
|
|
||||||
|
// Calculate the model size. This method should be called after iteration_to_model_ is initialized.
|
||||||
|
size_t ComputeModelSize();
|
||||||
|
|
||||||
|
size_t max_model_count_;
|
||||||
|
size_t model_size_;
|
||||||
|
std::map<size_t, std::shared_ptr<MemoryRegister>> iteration_to_model_;
|
||||||
|
};
|
||||||
|
} // namespace server
|
||||||
|
} // namespace ps
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_CCSRC_PS_SERVER_MODEL_STORE_H_
|
|
@ -25,15 +25,15 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace ps {
|
||||||
namespace server {
|
namespace server {
|
||||||
bool ParameterAggregator::Init(const CNodePtr &cnode, size_t required_count) {
|
bool ParameterAggregator::Init(const CNodePtr &cnode, size_t threshold_count) {
|
||||||
MS_EXCEPTION_IF_NULL(cnode);
|
MS_EXCEPTION_IF_NULL(cnode);
|
||||||
memory_register_ = std::make_shared<MemoryRegister>();
|
memory_register_ = std::make_shared<MemoryRegister>();
|
||||||
MS_EXCEPTION_IF_NULL(memory_register_);
|
MS_EXCEPTION_IF_NULL(memory_register_);
|
||||||
|
|
||||||
required_push_count_ = required_count;
|
required_push_count_ = threshold_count;
|
||||||
// The required_pull_count_ is the count for Pull, which should be the same as required_push_count_.
|
// The required_pull_count_ is the count for Pull, which should be the same as required_push_count_.
|
||||||
// required_pull_count_ normally used in parameter server training mode.
|
// required_pull_count_ normally used in parameter server training mode.
|
||||||
required_pull_count_ = required_count;
|
required_pull_count_ = threshold_count;
|
||||||
|
|
||||||
MS_LOG(DEBUG) << "Start initializing kernels for " << AnfAlgo::GetCNodeName(cnode);
|
MS_LOG(DEBUG) << "Start initializing kernels for " << AnfAlgo::GetCNodeName(cnode);
|
||||||
InitAggregationKernels(cnode);
|
InitAggregationKernels(cnode);
|
||||||
|
|
|
@ -61,8 +61,8 @@ class ParameterAggregator {
|
||||||
~ParameterAggregator() = default;
|
~ParameterAggregator() = default;
|
||||||
|
|
||||||
// Initialize ParameterAggregator with a cnode. This cnode is normally a optimizer kernel for now.
|
// Initialize ParameterAggregator with a cnode. This cnode is normally a optimizer kernel for now.
|
||||||
// The parameter required_count helps ParameterAggregator to judge the current status if it's stateful.
|
// The parameter threshold_count helps ParameterAggregator to judge the current status if it's stateful.
|
||||||
bool Init(const CNodePtr &cnode, size_t required_count = 0);
|
bool Init(const CNodePtr &cnode, size_t threshold_count = 0);
|
||||||
|
|
||||||
// Update old data stored in ParameterAggregator with new data.
|
// Update old data stored in ParameterAggregator with new data.
|
||||||
// The data could have many meanings: weights, gradients, learning_rate, momentum, etc.
|
// The data could have many meanings: weights, gradients, learning_rate, momentum, etc.
|
||||||
|
|
|
@ -0,0 +1,139 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "ps/server/round.h"
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace ps {
|
||||||
|
namespace server {
|
||||||
|
Round::Round(const std::string &name, bool check_timeout, size_t time_window, bool check_count, size_t threshold_count)
|
||||||
|
: name_(name),
|
||||||
|
check_timeout_(check_timeout),
|
||||||
|
time_window_(time_window),
|
||||||
|
check_count_(check_count),
|
||||||
|
threshold_count_(threshold_count) {}
|
||||||
|
|
||||||
|
// Wires this round to the communicator and prepares its callbacks.
//   communicator        - must not be null; messages registered under name_ are routed to LaunchRoundKernel.
//   timeout_cb          - invoked by the iteration timer when this round times out (only if check_timeout_).
//   finish_iteration_cb - invoked through finish_iteration_cb_ when the iteration is finished.
// If check_count_ is set, a counter with threshold_count_ is registered on DistributedCountService.
void Round::Initialize(const std::shared_ptr<core::CommunicatorBase> &communicator, TimeOutCb timeout_cb,
                       FinishIterCb finish_iteration_cb) {
  MS_EXCEPTION_IF_NULL(communicator);
  communicator_ = communicator;

  // Register callback for round kernel.
  communicator_->RegisterMsgCallBack(
    name_, [this](std::shared_ptr<core::MessageHandler> message) { LaunchRoundKernel(message); });

  // Callback when the iteration is finished.
  finish_iteration_cb_ = [this, finish_iteration_cb]() {
    MS_LOG(INFO) << "Round " << name_ << " finished! Proceed to next iteration.";
    finish_iteration_cb();
  };

  // Callback for finalizing the server. This can only be called once.
  finalize_cb_ = [this]() { communicator_->Stop(); };

  if (check_timeout_) {
    iter_timer_ = std::make_shared<IterationTimer>();

    // 1.Set the timeout callback for the timer.
    iter_timer_->SetTimeOutCallBack([this, timeout_cb]() {
      MS_LOG(INFO) << "Round " << name_ << " timeout! Proceed to next iteration.";
      timeout_cb();
    });

    // 2.Stopping timer callback which will be set to the round kernel.
    stop_timer_cb_ = [this]() {
      MS_LOG(INFO) << "Round " << name_ << " kernel stops its timer.";
      iter_timer_->Stop();
    };
  }

  // Set counter event callbacks for this round if the round kernel is stateful.
  if (!check_count_) {
    return;
  }
  auto on_first_count = std::bind(&Round::OnFirstCountEvent, this, std::placeholders::_1);
  auto on_last_count = std::bind(&Round::OnLastCountEvent, this, std::placeholders::_1);
  DistributedCountService::GetInstance().RegisterCounter(name_, threshold_count_, {on_first_count, on_last_count});
}
|
||||||
|
|
||||||
|
// Attaches the kernel that serves this round and hands it the stop-timer and finish-iteration callbacks.
// The kernel must not be null.
void Round::BindRoundKernel(const std::shared_ptr<kernel::RoundKernel> &kernel) {
  MS_EXCEPTION_IF_NULL(kernel);

  kernel_ = kernel;
  // The callbacks are copied from this round's members as they are at this moment.
  kernel_->set_stop_timer_cb(stop_timer_cb_);
  kernel_->set_finish_iteration_cb(finish_iteration_cb_);
}
|
||||||
|
|
||||||
|
// Message callback of this round: wraps the message payload as the kernel input, launches the bound kernel and
// sends the kernel output back as the response. If the kernel produced no output, a textual reason is sent
// instead.
void Round::LaunchRoundKernel(const std::shared_ptr<core::MessageHandler> &message) {
  if (message == nullptr) {
    MS_LOG(ERROR) << "Message is nullptr.";
    return;
  }

  auto request = std::make_shared<Address>();
  auto response = std::make_shared<Address>();
  request->addr = message->data();
  request->size = message->len();

  const bool launched = kernel_->Launch({request}, {}, {response});

  if (response->size == 0) {
    const std::string reason = "The output of the round " + name_ + " is empty.";
    MS_LOG(WARNING) << reason;
    communicator_->SendResponse(reason.c_str(), reason.size(), message);
    return;
  }

  // Must send response back no matter what value Launch method returns.
  if (!launched) {
    MS_LOG(WARNING) << "Launching round kernel of round " << name_ << " failed.";
  }
  communicator_->SendResponse(response->addr, response->size, message);
  kernel_->Release(response);
}
|
||||||
|
|
||||||
|
// Resets the round by resetting the kernel bound to it.
void Round::Reset() {
  kernel_->Reset();
}
|
||||||
|
|
||||||
|
// Returns the name of this round.
const std::string &Round::name() const {
  return name_;
}
|
||||||
|
|
||||||
|
// Returns the count threshold configured for this round.
size_t Round::threshold_count() const {
  return threshold_count_;
}
|
||||||
|
|
||||||
|
// Returns the time window configured for this round; the timer is started with it as a millisecond duration.
size_t Round::time_window() const {
  return time_window_;
}
|
||||||
|
|
||||||
|
// Handler for the first count event of this round. The message argument is unused.
void Round::OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &) {
  MS_LOG(INFO) << "Round " << name_ << " first count event is triggered.";
  if (!check_timeout_) {
    return;
  }
  // The timer starts only after the first count event is triggered by DistributedCountService.
  iter_timer_->Start(std::chrono::milliseconds(time_window_));
}
|
||||||
|
|
||||||
|
// Handler for the last count event of this round: stops the timer if this round is timed, then forwards the
// event to the bound kernel.
void Round::OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message) {
  MS_LOG(INFO) << "Round " << name_ << " last count event is triggered.";

  // Same as the first count event, the timer must be stopped by DistributedCountService.
  const bool timed_round = check_timeout_;
  if (timed_round) {
    iter_timer_->Stop();
  }

  // Some kernels override the OnLastCountEvent method.
  kernel_->OnLastCountEvent(message);
}
|
||||||
|
} // namespace server
|
||||||
|
} // namespace ps
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,95 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_PS_SERVER_ROUND_H_
|
||||||
|
#define MINDSPORE_CCSRC_PS_SERVER_ROUND_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include "ps/core/communicator/communicator_base.h"
|
||||||
|
#include "ps/server/common.h"
|
||||||
|
#include "ps/server/iteration_timer.h"
|
||||||
|
#include "ps/server/distributed_count_service.h"
|
||||||
|
#include "ps/server/kernel/round/round_kernel.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace ps {
|
||||||
|
namespace server {
|
||||||
|
// Round helps server to handle network round messages and launch round kernels. One iteration in server consists of
|
||||||
|
// multiple rounds like startFLJob, updateModel, Push, Pull, etc. Some round kernels may be stateful because of counting
|
||||||
|
// and timing. So Round helps register counter and timer so that the round kernels only need to focus on the logic.
|
||||||
|
class Round {
|
||||||
|
public:
|
||||||
|
explicit Round(const std::string &name, bool check_timeout = true, size_t time_window = 3000,
|
||||||
|
bool check_count = false, size_t threshold_count = 8);
|
||||||
|
~Round() = default;
|
||||||
|
|
||||||
|
void Initialize(const std::shared_ptr<core::CommunicatorBase> &communicator, TimeOutCb timeout_cb,
|
||||||
|
FinishIterCb finish_iteration_cb);
|
||||||
|
|
||||||
|
// Bind a round kernel to this Round. This method should be called after Initialize.
|
||||||
|
void BindRoundKernel(const std::shared_ptr<kernel::RoundKernel> &kernel);
|
||||||
|
|
||||||
|
// This method is the callback which will be set to the communicator and called after the corresponding round message
|
||||||
|
// is sent to the server.
|
||||||
|
void LaunchRoundKernel(const std::shared_ptr<core::MessageHandler> &message);
|
||||||
|
|
||||||
|
// Round needs to be reset after each iteration is finished or its timer expires.
|
||||||
|
void Reset();
|
||||||
|
|
||||||
|
const std::string &name() const;
|
||||||
|
size_t threshold_count() const;
|
||||||
|
size_t time_window() const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
// The callbacks which will be set to DistributedCounterService.
|
||||||
|
void OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &message);
|
||||||
|
void OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message);
|
||||||
|
|
||||||
|
std::string name_;
|
||||||
|
|
||||||
|
// Whether this round needs to use timer. Most rounds in federated learning with mobile devices scenario need to set
|
||||||
|
// check_timeout_ to true.
|
||||||
|
bool check_timeout_;
|
||||||
|
|
||||||
|
// The time window duration for this round when check_timeout_ is set to true.
|
||||||
|
size_t time_window_;
|
||||||
|
|
||||||
|
// If check_count_ is true, it means the round has to do counting for every round message and the first/last count
|
||||||
|
// event will be triggered.
|
||||||
|
bool check_count_;
|
||||||
|
|
||||||
|
// The threshold count for this round when check_count_ is set to true. The logic of this round has to check whether
|
||||||
|
// the round message count has reached threshold_count_.
|
||||||
|
size_t threshold_count_;
|
||||||
|
|
||||||
|
std::shared_ptr<core::CommunicatorBase> communicator_;
|
||||||
|
|
||||||
|
// The round kernel for this Round.
|
||||||
|
std::shared_ptr<kernel::RoundKernel> kernel_;
|
||||||
|
|
||||||
|
// Some rounds may need timer to eliminate the long tail effect.
|
||||||
|
std::shared_ptr<IterationTimer> iter_timer_;
|
||||||
|
|
||||||
|
// The callbacks which will be set to the round kernel.
|
||||||
|
StopTimerCb stop_timer_cb_;
|
||||||
|
FinishIterCb finish_iteration_cb_;
|
||||||
|
FinalizeCb finalize_cb_;
|
||||||
|
};
|
||||||
|
} // namespace server
|
||||||
|
} // namespace ps
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_CCSRC_PS_SERVER_ROUND_H_
|
|
@ -31,7 +31,7 @@
|
||||||
#include "utils/utils.h"
|
#include "utils/utils.h"
|
||||||
#include "frontend/parallel/context.h"
|
#include "frontend/parallel/context.h"
|
||||||
#include "debug/env_config_parser.h"
|
#include "debug/env_config_parser.h"
|
||||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
#if (ENABLE_CPU && !_WIN32)
|
||||||
#include "ps/ps_cache/ps_cache_manager.h"
|
#include "ps/ps_cache/ps_cache_manager.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
@ -307,7 +307,7 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) {
|
||||||
}
|
}
|
||||||
need_alloc_nodes.push_back(item);
|
need_alloc_nodes.push_back(item);
|
||||||
}
|
}
|
||||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
#if (ENABLE_CPU && !_WIN32)
|
||||||
bool ps_cache_check = false;
|
bool ps_cache_check = false;
|
||||||
#endif
|
#endif
|
||||||
for (auto &item : need_alloc_nodes) {
|
for (auto &item : need_alloc_nodes) {
|
||||||
|
@ -320,7 +320,7 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
DeviceAddressPtr device_address = nullptr;
|
DeviceAddressPtr device_address = nullptr;
|
||||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
#if (ENABLE_CPU && !_WIN32)
|
||||||
const std::string ¶m_name = item->fullname_with_scope();
|
const std::string ¶m_name = item->fullname_with_scope();
|
||||||
if (ps::ps_cache_instance.IsHashTable(param_name)) {
|
if (ps::ps_cache_instance.IsHashTable(param_name)) {
|
||||||
MS_LOG(INFO) << "Parameter(" << param_name << ")"
|
MS_LOG(INFO) << "Parameter(" << param_name << ")"
|
||||||
|
@ -1038,7 +1038,7 @@ DeviceAddressPtr KernelRuntime::AssignSingleOpLaunchMemory(size_t size, const st
|
||||||
return device_address;
|
return device_address;
|
||||||
}
|
}
|
||||||
|
|
||||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
#if (ENABLE_CPU && !_WIN32)
|
||||||
void KernelRuntime::GetFirstPSEmbeddingCache(const session::KernelGraph *graph,
|
void KernelRuntime::GetFirstPSEmbeddingCache(const session::KernelGraph *graph,
|
||||||
AnfNodePtr *const first_cache_input_index,
|
AnfNodePtr *const first_cache_input_index,
|
||||||
size_t *const first_cache_size) {
|
size_t *const first_cache_size) {
|
||||||
|
|
|
@ -142,7 +142,7 @@ class KernelRuntime {
|
||||||
void RunOpAssignOutputNodeMemory(const ValuePtr &pre_output_value, session::KernelGraph *graph);
|
void RunOpAssignOutputNodeMemory(const ValuePtr &pre_output_value, session::KernelGraph *graph);
|
||||||
void AssignValueNodeTensor(const ValueNodePtr &value_node, const ValuePtr &node_value, size_t output_idx);
|
void AssignValueNodeTensor(const ValueNodePtr &value_node, const ValuePtr &node_value, size_t output_idx);
|
||||||
DeviceAddressPtr PreAssignCNodeMemory(const AnfNodePtr &anf_node, size_t index);
|
DeviceAddressPtr PreAssignCNodeMemory(const AnfNodePtr &anf_node, size_t index);
|
||||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
#if (ENABLE_CPU && !_WIN32)
|
||||||
void GetFirstPSEmbeddingCache(const session::KernelGraph *graph, AnfNodePtr *const first_cache_input_index,
|
void GetFirstPSEmbeddingCache(const session::KernelGraph *graph, AnfNodePtr *const first_cache_input_index,
|
||||||
size_t *const first_cache_size);
|
size_t *const first_cache_size);
|
||||||
void CheckIfSupportPSEmbeddingCache(const session::KernelGraph *graph);
|
void CheckIfSupportPSEmbeddingCache(const session::KernelGraph *graph);
|
||||||
|
|
|
@ -16,14 +16,14 @@
|
||||||
|
|
||||||
#include "runtime/device/kernel_runtime_manager.h"
|
#include "runtime/device/kernel_runtime_manager.h"
|
||||||
#include "utils/log_adapter.h"
|
#include "utils/log_adapter.h"
|
||||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
#if (ENABLE_CPU && !_WIN32)
|
||||||
#include "ps/ps_cache/ps_cache_manager.h"
|
#include "ps/ps_cache/ps_cache_manager.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace device {
|
namespace device {
|
||||||
void KernelRuntimeManager::ClearRuntimeResource() {
|
void KernelRuntimeManager::ClearRuntimeResource() {
|
||||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
#if (ENABLE_CPU && !_WIN32)
|
||||||
if (ps::PSContext::instance()->is_worker() && ps::PsDataPrefetch::GetInstance().cache_enable()) {
|
if (ps::PSContext::instance()->is_worker() && ps::PsDataPrefetch::GetInstance().cache_enable()) {
|
||||||
ps::ps_cache_instance.SyncEmbeddingTable();
|
ps::ps_cache_instance.SyncEmbeddingTable();
|
||||||
}
|
}
|
||||||
|
@ -125,7 +125,7 @@ void KernelRuntimeManager::ReleaseKernelRuntime(const std::string &device_name,
|
||||||
if (runtime == nullptr) {
|
if (runtime == nullptr) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
#if (ENABLE_CPU && !_WIN32)
|
||||||
if (ps::PSContext::instance()->is_worker() && ps::PsDataPrefetch::GetInstance().cache_enable()) {
|
if (ps::PSContext::instance()->is_worker() && ps::PsDataPrefetch::GetInstance().cache_enable()) {
|
||||||
ps::ps_cache_instance.SyncEmbeddingTable();
|
ps::ps_cache_instance.SyncEmbeddingTable();
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,123 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
namespace mindspore.schema;
|
||||||
|
|
||||||
|
table CipherPublicParams {
|
||||||
|
t:int;
|
||||||
|
p:[ubyte];
|
||||||
|
g:int;
|
||||||
|
prime:[ubyte];
|
||||||
|
dp_eps:float;
|
||||||
|
dp_delta:float;
|
||||||
|
dp_norm_clip:float;
|
||||||
|
encrypt_type:int;
|
||||||
|
}
|
||||||
|
|
||||||
|
table ClientPublicKeys {
|
||||||
|
fl_id:string;
|
||||||
|
c_pk:[ubyte];
|
||||||
|
s_pk: [ubyte];
|
||||||
|
}
|
||||||
|
|
||||||
|
table ClientShare {
|
||||||
|
fl_id:string;
|
||||||
|
share:[ubyte];
|
||||||
|
index:int;
|
||||||
|
}
|
||||||
|
|
||||||
|
table RequestExchangeKeys{
|
||||||
|
fl_id:string;
|
||||||
|
c_pk:[ubyte];
|
||||||
|
s_pk:[ubyte];
|
||||||
|
iteration:int;
|
||||||
|
timestamp:string;
|
||||||
|
}
|
||||||
|
|
||||||
|
table ResponseExchangeKeys{
|
||||||
|
retcode:int;
|
||||||
|
reason:string;
|
||||||
|
next_req_time:string;
|
||||||
|
iteration:int;
|
||||||
|
}
|
||||||
|
|
||||||
|
table GetExchangeKeys{
|
||||||
|
fl_id:string;
|
||||||
|
iteration:int;
|
||||||
|
timestamp:string;
|
||||||
|
}
|
||||||
|
|
||||||
|
table ReturnExchangeKeys{
|
||||||
|
retcode:int;
|
||||||
|
iteration:int;
|
||||||
|
remote_publickeys:[ClientPublicKeys];
|
||||||
|
next_req_time:string;
|
||||||
|
}
|
||||||
|
|
||||||
|
table RequestShareSecrets{
|
||||||
|
fl_id:string;
|
||||||
|
encrypted_shares:[ClientShare];
|
||||||
|
iteration:int;
|
||||||
|
timestamp:string;
|
||||||
|
}
|
||||||
|
|
||||||
|
table ResponseShareSecrets{
|
||||||
|
retcode:int;
|
||||||
|
reason:string;
|
||||||
|
next_req_time:string;
|
||||||
|
iteration:int;
|
||||||
|
}
|
||||||
|
|
||||||
|
table GetShareSecrets{
|
||||||
|
fl_id:string;
|
||||||
|
iteration:int;
|
||||||
|
timestamp:string;
|
||||||
|
}
|
||||||
|
|
||||||
|
table ReturnShareSecrets{
|
||||||
|
retcode:int;
|
||||||
|
iteration:int;
|
||||||
|
encrypted_shares: [ClientShare];
|
||||||
|
next_req_time:string;
|
||||||
|
}
|
||||||
|
|
||||||
|
table GetClientList{
|
||||||
|
fl_id:string;
|
||||||
|
iteration:int;
|
||||||
|
timestamp:string;
|
||||||
|
}
|
||||||
|
|
||||||
|
table ReturnClientList{
|
||||||
|
retcode:int;
|
||||||
|
reason:string;
|
||||||
|
clients:[string];
|
||||||
|
iteration:int;
|
||||||
|
next_req_time:string;
|
||||||
|
}
|
||||||
|
|
||||||
|
table SendReconstructSecret{
|
||||||
|
fl_id:string;
|
||||||
|
reconstruct_secret_shares:[ClientShare];
|
||||||
|
iteration:int;
|
||||||
|
timestamp:string;
|
||||||
|
}
|
||||||
|
|
||||||
|
table ReconstructSecret{
|
||||||
|
retcode:int;
|
||||||
|
reason:string;
|
||||||
|
iteration:int;
|
||||||
|
next_req_time:string;
|
||||||
|
}
|
|
@ -0,0 +1,159 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
include "cipher.fbs";
|
||||||
|
namespace mindspore.schema;
|
||||||
|
|
||||||
|
file_identifier "FLJ0";
|
||||||
|
file_extension "fl";
|
||||||
|
|
||||||
|
enum ResponseCode: int {
|
||||||
|
SUCCEED=200,
|
||||||
|
SucNotReady=201,
|
||||||
|
RepeatRequest=202,
|
||||||
|
SucNotMatch=204,
|
||||||
|
OutOfTime=300,
|
||||||
|
NotSelected=301,
|
||||||
|
RequestError=400,
|
||||||
|
SystemError=500
|
||||||
|
}
|
||||||
|
|
||||||
|
// Values accepted by the `type` field of the Aggregation table declared below.
// NOTE(review): "FedAdagrag" looks like a misspelling of "FedAdagrad". Renaming it would change the identifier in the
// generated code, so confirm with the users of the generated headers before fixing it.
enum AggregationType:byte {FedAvg=0, FedAdam = 1, FedAdagrag=2, FedMeta=3, qffl=4}
|
||||||
|
// Values accepted by the `metrics` vector of the FLPlan table declared later in this schema.
enum Metrics:byte {accuracy = 0, precision = 1, recall = 2, AUC = 3,f1=4, fbeta=5}
|
||||||
|
// Values accepted by the `early_stop_type` field of the EarlyStop table declared below.
enum EarlyStopType:byte {loss_diff = 0, loss_abs = 1, weight_diff = 2}
|
||||||
|
|
||||||
|
table Aggregation {
|
||||||
|
type:AggregationType;
|
||||||
|
weights:[float];
|
||||||
|
}
|
||||||
|
|
||||||
|
table EarlyStop {
|
||||||
|
early_stop_type:EarlyStopType;
|
||||||
|
weight:float;
|
||||||
|
rounds:int;
|
||||||
|
}
|
||||||
|
|
||||||
|
table FeatureMap{
|
||||||
|
weight_fullname:string;
|
||||||
|
data:[float];
|
||||||
|
}
|
||||||
|
table RequestFLJob{
|
||||||
|
fl_name:string;
|
||||||
|
fl_id:string;
|
||||||
|
iteration:int;
|
||||||
|
data_size:int;
|
||||||
|
timestamp:string;
|
||||||
|
}
|
||||||
|
table ResponseFLJob {
|
||||||
|
retcode:int;
|
||||||
|
reason:string;
|
||||||
|
iteration:int;
|
||||||
|
is_selected:bool = false;
|
||||||
|
next_req_time:string;
|
||||||
|
fl_plan_config:FLPlan;
|
||||||
|
feature_map:[FeatureMap];
|
||||||
|
timestamp:string;
|
||||||
|
}
|
||||||
|
|
||||||
|
table FLPlan {
|
||||||
|
fl_name:string;
|
||||||
|
iterations:int;
|
||||||
|
epochs:int;
|
||||||
|
early_stop:EarlyStop;
|
||||||
|
mini_batch:int;
|
||||||
|
shuffle:bool = false;
|
||||||
|
lr:float;
|
||||||
|
aggregation:Aggregation;
|
||||||
|
metrics:[Metrics];
|
||||||
|
cipher:CipherPublicParams;
|
||||||
|
}
|
||||||
|
|
||||||
|
table RequestUpdateModel{
|
||||||
|
fl_name:string;
|
||||||
|
fl_id:string;
|
||||||
|
iteration:int;
|
||||||
|
feature_map:[FeatureMap];
|
||||||
|
timestamp:string;
|
||||||
|
}
|
||||||
|
table ResponseUpdateModel{
|
||||||
|
retcode:int;
|
||||||
|
reason:string;
|
||||||
|
feature_map:[FeatureMap];
|
||||||
|
next_req_time:string;
|
||||||
|
timestamp:string;
|
||||||
|
}
|
||||||
|
|
||||||
|
table RequestAsyncUpdateModel{
|
||||||
|
fl_name:string;
|
||||||
|
fl_id:string;
|
||||||
|
iteration:int;
|
||||||
|
data_size:int;
|
||||||
|
feature_map:[FeatureMap];
|
||||||
|
}
|
||||||
|
table ResponseAsyncUpdateModel{
|
||||||
|
retcode:int;
|
||||||
|
reason:string;
|
||||||
|
iteration:int;
|
||||||
|
}
|
||||||
|
|
||||||
|
table RequestOverwriteWeightsByKey{
|
||||||
|
iteration:int;
|
||||||
|
feature_map:[FeatureMap];
|
||||||
|
}
|
||||||
|
|
||||||
|
table ResponseOverwriteWeightsByKey{
|
||||||
|
retcode:int;
|
||||||
|
reason:string;
|
||||||
|
}
|
||||||
|
|
||||||
|
table RequestGetModel{
|
||||||
|
fl_name:string;
|
||||||
|
iteration:int;
|
||||||
|
timestamp:string;
|
||||||
|
}
|
||||||
|
table ResponseGetModel{
|
||||||
|
retcode:int;
|
||||||
|
reason:string;
|
||||||
|
iteration:int;
|
||||||
|
feature_map:[FeatureMap];
|
||||||
|
timestamp:string;
|
||||||
|
}
|
||||||
|
|
||||||
|
table RequestAsyncGetModel{
|
||||||
|
fl_name:string;
|
||||||
|
iteration:int;
|
||||||
|
}
|
||||||
|
table ResponseAsyncGetModel{
|
||||||
|
retcode:int;
|
||||||
|
reason:string;
|
||||||
|
iteration:int;
|
||||||
|
feature_map:[FeatureMap];
|
||||||
|
}
|
||||||
|
|
||||||
|
table RequestGetWeightsByKey{
|
||||||
|
iteration:int;
|
||||||
|
weight_names:[string];
|
||||||
|
}
|
||||||
|
table ResponseGetWeightsByKey{
|
||||||
|
retcode:int;
|
||||||
|
reason:string;
|
||||||
|
feature_map:[FeatureMap];
|
||||||
|
}
|
||||||
|
|
||||||
|
// FeatureMapList refers to the whole trained model.
|
||||||
|
table FeatureMapList {
|
||||||
|
feature_map:[FeatureMap];
|
||||||
|
}
|
Loading…
Reference in New Issue