From 232e525ff2dd767c91a162e09b3a9b05dda77204 Mon Sep 17 00:00:00 2001
From: ZPaC
Date: Sat, 20 Feb 2021 09:23:48 +0800
Subject: [PATCH] Add common.h, local_meta_storage.h, param_name_type.h

Add class Message, AddrStorage
Add Kernel folder
Add consistent hash ring
Add updater for server.
Add iteration timer
Add executor for server
---
 mindspore/ccsrc/ps/CMakeLists.txt             |  14 +
 mindspore/ccsrc/ps/server/common.h            | 189 +++++++++++
 .../ccsrc/ps/server/consistent_hash_ring.cc   |  58 ++++
 .../ccsrc/ps/server/consistent_hash_ring.h    |  64 ++++
 mindspore/ccsrc/ps/server/executor.cc         | 315 +++++++++++++++++
 mindspore/ccsrc/ps/server/executor.h          | 125 +++++++
 mindspore/ccsrc/ps/server/iteration_timer.cc  |  56 +++
 mindspore/ccsrc/ps/server/iteration_timer.h   |  64 ++++
 .../ps/server/kernel/aggregation_kernel.h     |  95 ++++++
 .../kernel/aggregation_kernel_factory.cc      |  71 ++++
 .../kernel/aggregation_kernel_factory.h       |  71 ++++
 .../ps/server/kernel/apply_momentum_kernel.cc |  34 ++
 .../ps/server/kernel/apply_momentum_kernel.h  |  61 ++++
 .../server/kernel/dense_grad_accum_kernel.cc  |  30 ++
 .../server/kernel/dense_grad_accum_kernel.h   |  95 ++++++
 .../ccsrc/ps/server/kernel/kernel_factory.h   |  92 +++++
 .../ccsrc/ps/server/kernel/optimizer_kernel.h |  97 ++++++
 .../server/kernel/optimizer_kernel_factory.cc |  70 ++++
 .../server/kernel/optimizer_kernel_factory.h  |  64 ++++
 .../ccsrc/ps/server/kernel/params_info.cc     |  68 ++++
 .../ccsrc/ps/server/kernel/params_info.h      |  70 ++++
 .../ccsrc/ps/server/local_meta_storage.cc     |  46 +++
 .../ccsrc/ps/server/local_meta_storage.h      |  88 +++++
 mindspore/ccsrc/ps/server/memory_register.cc  |  34 ++
 mindspore/ccsrc/ps/server/memory_register.h   |  88 +++++
 .../ccsrc/ps/server/parameter_aggregator.cc   | 321 ++++++++++++++++++
 .../ccsrc/ps/server/parameter_aggregator.h    | 139 ++++++++
 tests/ut/cpp/CMakeLists.txt                   |   1 +
 28 files changed, 2520 insertions(+)
 create mode 100644 mindspore/ccsrc/ps/server/common.h
 create mode 100644 mindspore/ccsrc/ps/server/consistent_hash_ring.cc
 create mode 100644 mindspore/ccsrc/ps/server/consistent_hash_ring.h
 create mode 100644 mindspore/ccsrc/ps/server/executor.cc
 create mode 100644 mindspore/ccsrc/ps/server/executor.h
 create mode 100644 mindspore/ccsrc/ps/server/iteration_timer.cc
 create mode 100644 mindspore/ccsrc/ps/server/iteration_timer.h
 create mode 100644 mindspore/ccsrc/ps/server/kernel/aggregation_kernel.h
 create mode 100644 mindspore/ccsrc/ps/server/kernel/aggregation_kernel_factory.cc
 create mode 100644 mindspore/ccsrc/ps/server/kernel/aggregation_kernel_factory.h
 create mode 100644 mindspore/ccsrc/ps/server/kernel/apply_momentum_kernel.cc
 create mode 100644 mindspore/ccsrc/ps/server/kernel/apply_momentum_kernel.h
 create mode 100644 mindspore/ccsrc/ps/server/kernel/dense_grad_accum_kernel.cc
 create mode 100644 mindspore/ccsrc/ps/server/kernel/dense_grad_accum_kernel.h
 create mode 100644 mindspore/ccsrc/ps/server/kernel/kernel_factory.h
 create mode 100644 mindspore/ccsrc/ps/server/kernel/optimizer_kernel.h
 create mode 100644 mindspore/ccsrc/ps/server/kernel/optimizer_kernel_factory.cc
 create mode 100644 mindspore/ccsrc/ps/server/kernel/optimizer_kernel_factory.h
 create mode 100644 mindspore/ccsrc/ps/server/kernel/params_info.cc
 create mode 100644 mindspore/ccsrc/ps/server/kernel/params_info.h
 create mode 100644 mindspore/ccsrc/ps/server/local_meta_storage.cc
 create mode 100644 mindspore/ccsrc/ps/server/local_meta_storage.h
 create mode 100644 mindspore/ccsrc/ps/server/memory_register.cc
 create mode 100644 mindspore/ccsrc/ps/server/memory_register.h
 create mode 100644 mindspore/ccsrc/ps/server/parameter_aggregator.cc
 create mode 100644 mindspore/ccsrc/ps/server/parameter_aggregator.h

diff --git a/mindspore/ccsrc/ps/CMakeLists.txt b/mindspore/ccsrc/ps/CMakeLists.txt
index bdf008ca881..a7752af1585 100644
--- a/mindspore/ccsrc/ps/CMakeLists.txt
+++ b/mindspore/ccsrc/ps/CMakeLists.txt
@@ -39,6 +39,20 @@ if(NOT ENABLE_GPU)
     list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/gpu/gpu_ps_cache.cc")
 endif()
 
+if(WIN32 OR NOT ENABLE_CPU)
+    list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/apply_momentum_kernel.cc")
+    list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/aggregation_kernel_factory.cc")
+    list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/dense_grad_accum_kernel.cc")
+    list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/optimizer_kernel_factory.cc")
+    list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/params_info.cc")
+    list(REMOVE_ITEM _PS_SRC_FILES "server/consistent_hash_ring.cc")
+    list(REMOVE_ITEM _PS_SRC_FILES "server/iteration_timer.cc")
+    list(REMOVE_ITEM _PS_SRC_FILES "server/local_meta_storage.cc")
+    list(REMOVE_ITEM _PS_SRC_FILES "server/memory_register.cc")
+    list(REMOVE_ITEM _PS_SRC_FILES "server/parameter_aggregator.cc")
+    list(REMOVE_ITEM _PS_SRC_FILES "server/executor.cc")
+endif()
+
 list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_data/ps_data_prefetch.cc")
 list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_data/ps_data_channel.cc")
 add_subdirectory(ps_cache)
diff --git a/mindspore/ccsrc/ps/server/common.h b/mindspore/ccsrc/ps/server/common.h
new file mode 100644
index 00000000000..053f00024e9
--- /dev/null
+++ b/mindspore/ccsrc/ps/server/common.h
@@ -0,0 +1,189 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PS_SERVER_COMMON_H_
+#define MINDSPORE_CCSRC_PS_SERVER_COMMON_H_
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include "proto/ps.pb.h"
+#include "ir/anf.h"
+#include "utils/utils.h"
+#include "ir/dtype/type_id.h"
+#include "backend/kernel_compiler/cpu/cpu_kernel.h"
+#include "ps/ps_context.h"
+#include "ps/core/communicator/http_message_handler.h"
+#include "ps/core/communicator/tcp_server.h"
+
+namespace mindspore {
+namespace ps {
+namespace server {
+// Definitions for the server framework.
+enum ServerMode { PARAMETER_SERVER = 0, FL_SERVER };
+enum CommType { HTTP = 0, TCP };
+enum AggregationType { FedAvg = 0, FedAdam, FedAdagrad, FedMeta, qffl, DenseGradAccum, SparseGradAccum };
+
+using kernel::Address;
+using kernel::AddressPtr;
+using kernel::CPUKernel;
+using TimeOutCb = std::function<void(void)>;
+using StopTimerCb = std::function<void(void)>;
+using FinishIterCb = std::function<void(void)>;
+using FinalizeCb = std::function<void(void)>;
+
+// Information about whether a server kernel will reuse kernel node memory from the front end.
+// Key refers to the server kernel's parameter name, like "weights", "grad", "learning_rate".
+// Value refers to the kernel node's parameter index.
+using ReuseKernelNodeInfo = std::map<std::string, size_t>;
+
+// UploadData refers to the data which is uploaded by workers.
+// Key refers to the data name. For example: "weights", "grad", "learning_rate", etc. This will be set by the worker.
+// Value refers to the data of the key.
+
+// We use Address instead of AddressPtr because:
+// 1. Address doesn't need to call make_shared so it has better performance.
+// 2. The data uploaded by the worker is normally parsed from FlatBuffers or Protobuf. For example: learning rate, new
+// weights, etc. Address is enough to store these data.
+
+// Pay attention that Address only stores the void* pointer of the data, so the data must not be released before the
+// related logic is done.
+using UploadData = std::map<std::string, Address>;
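+// A hypothetical sketch (not part of this patch) of how a worker-side payload maps onto UploadData;
+// grad_ptr/grad_nbytes/lr_ptr are illustrative names:
+//   UploadData upload;
+//   upload[kGradient] = {grad_ptr, grad_nbytes};      // void* pointer and size in bytes
+//   upload[kLearningRate] = {lr_ptr, sizeof(float)};
+// The pointed-to buffers must stay alive until the server has finished handling the request.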
+constexpr auto kWeight = "weight";
+constexpr auto kAccumulation = "accum";
+constexpr auto kLearningRate = "lr";
+constexpr auto kGradient = "grad";
+constexpr auto kNewGradient = "new_grad";
+constexpr auto kMomentum = "momentum";
+constexpr auto kIndices = "indices";
+constexpr auto kAdamM = "m";
+constexpr auto kAdamV = "v";
+constexpr auto kAdamBeta1Power = "beta1_power";
+constexpr auto kAdamBeta2Power = "beta2_power";
+constexpr auto kAdamBeta1 = "beta1";
+constexpr auto kAdamBeta2 = "beta2";
+constexpr auto kAdamEps = "eps";
+constexpr auto kFtrlLinear = "linear";
+
+// OptimParamNameToIndex represents every inputs/workspace/outputs parameter's offset when an optimizer kernel is
+// launched.
+using OptimParamNameToIndex = std::map<std::string, std::map<std::string, size_t>>;
+const OptimParamNameToIndex kMomentumNameToIdx = {
+  {"inputs", {{kWeight, 0}, {kAccumulation, 1}, {kLearningRate, 2}, {kGradient, 3}, {kMomentum, 4}}}, {"outputs", {}}};
+const OptimParamNameToIndex kAdamNameToIdx = {{"inputs",
+                                               {{kWeight, 0},
+                                                {kAdamM, 1},
+                                                {kAdamV, 2},
+                                                {kAdamBeta1Power, 3},
+                                                {kAdamBeta2Power, 4},
+                                                {kLearningRate, 5},
+                                                {kAdamBeta1, 6},
+                                                {kAdamBeta2, 7},
+                                                {kAdamEps, 8},
+                                                {kGradient, 9}}},
+                                              {"outputs", {}}};
+const OptimParamNameToIndex kSparseAdamNameToIdx = {{"inputs",
+                                                     {{kWeight, 0},
+                                                      {kAdamM, 1},
+                                                      {kAdamV, 2},
+                                                      {kAdamBeta1Power, 3},
+                                                      {kAdamBeta2Power, 4},
+                                                      {kLearningRate, 5},
+                                                      {kAdamBeta1, 6},
+                                                      {kAdamBeta2, 7},
+                                                      {kAdamEps, 8},
+                                                      {kGradient, 9},
+                                                      {kIndices, 10}}},
+                                                    {"outputs", {}}};
+const OptimParamNameToIndex kSparseFtrlNameToIdx = {
+  {"inputs", {{kWeight, 0}, {kAccumulation, 1}, {kFtrlLinear, 2}, {kGradient, 3}, {kIndices, 4}}}, {"outputs", {}}};
+const std::map<std::string, OptimParamNameToIndex> kNameToIdxMap = {
+  {kApplyMomentumOpName, kMomentumNameToIdx},
+  {kFusedSparseAdamName, kSparseAdamNameToIdx},
+  {kSparseApplyFtrlOpName, kSparseFtrlNameToIdx},
+  {kApplyAdamOpName, kAdamNameToIdx},
+};
+
+constexpr uint32_t kLeaderServerRank = 0;
+constexpr size_t kWorkerMgrThreadPoolSize = 32;
+constexpr size_t kWorkerMgrMaxTaskNum = 64;
+constexpr size_t kCipherMgrThreadPoolSize = 32;
+constexpr size_t kCipherMgrMaxTaskNum = 64;
+constexpr size_t kExecutorThreadPoolSize = 32;
+constexpr size_t kExecutorMaxTaskNum = 32;
+constexpr int kHttpSuccess = 200;
+constexpr auto kPBProtocol = "PB";
+constexpr auto kFBSProtocol = "FBS";
+constexpr auto kAggregationKernelType = "Aggregation";
+constexpr auto kOptimizerKernelType = "Optimizer";
+constexpr auto kCtxFuncGraph = "FuncGraph";
+constexpr auto kCtxIterNum = "iteration";
+constexpr auto kCtxDeviceMetas = "device_metas";
+constexpr auto kCtxTotalTimeoutDuration = "total_timeout_duration";
+constexpr auto kCtxUpdateModelClientList = "update_model_client_list";
+constexpr auto kCtxUpdateModelClientNum = "update_model_client_num";
+
+// This macro returns the current timestamp in milliseconds.
+#define CURRENT_TIME_MILLI \
+  std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now().time_since_epoch())
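+// Usage sketch (illustrative): CURRENT_TIME_MILLI expands to a std::chrono::milliseconds value,
+// so a raw integer timestamp can be taken with:
+//   int64_t now_ms = CURRENT_TIME_MILLI.count();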
+
+#define RETURN_IF_NULL(expr, ret)               \
+  if (expr == nullptr) {                        \
+    MS_LOG(ERROR) << #expr << " is nullptr.";   \
+    return ret;                                 \
+  }
+
+// This method returns the size in bytes of the given TypeId.
+inline size_t GetTypeIdByte(const TypeId &type) {
+  switch (type) {
+    case kNumberTypeFloat16:
+      return 2;
+    case kNumberTypeUInt32:
+    case kNumberTypeFloat32:
+      return 4;
+    case kNumberTypeUInt64:
+      return 8;
+    default:
+      MS_LOG(EXCEPTION) << "TypeId " << type << " not supported.";
+      return 0;
+  }
+}
+
+inline AddressPtr GenerateParameterNodeAddrPtr(const CNodePtr &kernel_node, size_t param_idx) {
+  RETURN_IF_NULL(kernel_node, nullptr);
+  auto param_node =
+    AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(kernel_node, param_idx), 0).first->cast<ParameterPtr>();
+  RETURN_IF_NULL(param_node, nullptr);
+  auto param_tensor = param_node->default_param()->cast<tensor::TensorPtr>();
+  RETURN_IF_NULL(param_tensor, nullptr);
+  AddressPtr addr = std::make_shared<Address>();
+  addr->addr = param_tensor->data_c();
+  addr->size = param_tensor->data().nbytes();
+  return addr;
+}
+
+// Definitions for Federated Learning.
+
+// Definitions for Parameter Server.
+
+}  // namespace server
+}  // namespace ps
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_PS_SERVER_COMMON_H_
diff --git a/mindspore/ccsrc/ps/server/consistent_hash_ring.cc b/mindspore/ccsrc/ps/server/consistent_hash_ring.cc
new file mode 100644
index 00000000000..7b6c4a746b3
--- /dev/null
+++ b/mindspore/ccsrc/ps/server/consistent_hash_ring.cc
@@ -0,0 +1,58 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ps/server/consistent_hash_ring.h"
+
+namespace mindspore {
+namespace ps {
+namespace server {
+bool ConsistentHashRing::Insert(uint32_t rank) {
+  std::string physical_node_hash_key = std::to_string(rank);
+  for (uint32_t i = 0; i < virtual_node_num_; i++) {
+    physical_node_hash_key += "#" + std::to_string(i);
+    MS_LOG(DEBUG) << "Insert virtual node " << physical_node_hash_key << " for node " << rank;
+
+    size_t hash_value = std::hash<std::string>()(physical_node_hash_key);
+    if (ring_.count(hash_value) != 0) {
+      MS_LOG(WARNING) << "Virtual node " << physical_node_hash_key << " is already mapped to the ring.";
+      continue;
+    }
+    ring_[hash_value] = rank;
+  }
+  return true;
+}
+
+bool ConsistentHashRing::Erase(uint32_t rank) {
+  for (auto iterator = ring_.begin(); iterator != ring_.end();) {
+    if (iterator->second == rank) {
+      ring_.erase(iterator++);
+    } else {
+      ++iterator;
+    }
+  }
+  return true;
+}
+
+uint32_t ConsistentHashRing::Find(const std::string &key) {
+  size_t hash_value = std::hash<std::string>()(key);
+  auto iterator = ring_.lower_bound(hash_value);
+  if (iterator == ring_.end()) {
+    // If the virtual node is not found clockwise, the key will be mapped to the first virtual node on the ring.
+    iterator = ring_.begin();
+  }
+  return iterator->second;
+}
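+// A minimal usage sketch (illustrative only, assuming four servers with ranks 0..3):
+//   ConsistentHashRing ring(128);
+//   for (uint32_t rank = 0; rank < 4; rank++) { ring.Insert(rank); }
+//   uint32_t owner_rank = ring.Find("fc1.weight");  // the server that stores this metadata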
+}  // namespace server
+}  // namespace ps
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/ps/server/consistent_hash_ring.h b/mindspore/ccsrc/ps/server/consistent_hash_ring.h
new file mode 100644
index 00000000000..c3e6d5223bb
--- /dev/null
+++ b/mindspore/ccsrc/ps/server/consistent_hash_ring.h
@@ -0,0 +1,64 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PS_SERVER_CONSISTENT_HASH_RING_H_
+#define MINDSPORE_CCSRC_PS_SERVER_CONSISTENT_HASH_RING_H_
+
+#include
+#include
+#include "utils/log_adapter.h"
+
+namespace mindspore {
+namespace ps {
+namespace server {
+// To support distributed storage and make servers easy to scale out and scale in under a large load of metadata,
+// we use class ConsistentHashRing to help servers find out which metadata is stored on which server node.
+
+// Class ConsistentHashRing implements the consistent hashing algorithm described in the paper.
+
+// This class creates a ring of hash values for metadata and server nodes. Each server could use this ring to
+// retrieve data stored on other servers according to the hash keys. The time complexity for adding/deleting/searching
+// in this algorithm is basically O(log n).
+class ConsistentHashRing {
+ public:
+  // The constructor parameter virtual_node_num is the number of virtual nodes to be created for each physical server
+  // node. According to the paper, these virtual nodes help spread data across all the servers and ensure balancing
+  // at the same time. And when we say "adding/deleting/searching", we are talking about operations on these virtual
+  // nodes instead of the physical nodes.
+  explicit ConsistentHashRing(uint32_t virtual_node_num = 128) : virtual_node_num_(virtual_node_num) {}
+  ~ConsistentHashRing() = default;
+
+  // Insert several virtual nodes for a server into this ring according to its rank id.
+  bool Insert(uint32_t rank);
+
+  // Remove the virtual nodes of a server according to its rank id.
+  bool Erase(uint32_t rank);
+
+  // Find the physical server node's rank according to the metadata's key.
+  uint32_t Find(const std::string &key);
+
+ private:
+  uint32_t virtual_node_num_;
+  // The hash ring for the server nodes.
+  // Key is the hash value of the virtual node.
+  // Value is the physical node's rank id.
+  std::map<size_t, uint32_t> ring_;
+};
+}  // namespace server
+}  // namespace ps
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_PS_SERVER_CONSISTENT_HASH_RING_H_
diff --git a/mindspore/ccsrc/ps/server/executor.cc b/mindspore/ccsrc/ps/server/executor.cc
new file mode 100644
index 00000000000..411ca988531
--- /dev/null
+++ b/mindspore/ccsrc/ps/server/executor.cc
@@ -0,0 +1,315 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ps/server/executor.h"
+#include
+#include
+#include
+#include
+
+namespace mindspore {
+namespace ps {
+namespace server {
+void Executor::Init(const FuncGraphPtr &func_graph, size_t aggregation_count) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  if (aggregation_count == 0) {
+    MS_LOG(EXCEPTION) << "Server aggregation count must be greater than 0";
+    return;
+  }
+  aggregation_count_ = aggregation_count;
+
+  // Initialize each trainable parameter's aggregator, including memory register, aggregation algorithms and
+  // optimizers.
+  bool ret = InitParamAggregator(func_graph);
+  if (!ret) {
+    MS_LOG(EXCEPTION) << "Initializing parameter aggregators failed.";
+    return;
+  }
+  initialized_ = true;
+  return;
+}
+
+bool Executor::initialized() const { return initialized_; }
+
+bool Executor::HandlePush(const std::string &param_name, const UploadData &upload_data) {
+  MS_LOG(DEBUG) << "Do Push for parameter " << param_name;
+  if (param_aggrs_.count(param_name) == 0) {
+    MS_LOG(WARNING) << "Parameter " << param_name << " is not registered in server.";
+    return false;
+  }
+
+  std::mutex &mtx = parameter_mutex_[param_name];
+  std::unique_lock<std::mutex> lock(mtx);
+  auto &param_aggr = param_aggrs_[param_name];
+
+  // A Push operation needs to wait until the pulling process is done.
+  while (!param_aggr->IsPullingDone()) {
+    lock.unlock();
+    std::this_thread::sleep_for(std::chrono::milliseconds(5));
+    lock.lock();
+  }
+
+  // 1. Update data with the uploaded data of the worker.
+  if (!param_aggr->UpdateData(upload_data)) {
+    MS_LOG(ERROR) << "Updating data for parameter " << param_name << " failed.";
+    return false;
+  }
+  // 2. Launch aggregation for this trainable parameter.
+  if (!param_aggr->LaunchAggregators()) {
+    MS_LOG(ERROR) << "Launching aggregators for parameter " << param_name << " failed.";
+    return false;
+  }
+  if (param_aggr->IsAggregationDone()) {
+    // 3. After the aggregation is done, optimize the trainable parameter.
+    if (!param_aggr->LaunchOptimizers()) {
+      MS_LOG(ERROR) << "Optimizing for parameter " << param_name << " failed.";
+      return false;
+    }
+    // 4. Reset pulling and aggregation status after optimizing is done.
+    param_aggr->ResetPullingStatus();
+    param_aggr->ResetAggregationStatus();
+  }
+  return true;
+}
+
+bool Executor::HandleModelUpdate(const std::string &param_name, const UploadData &upload_data) {
+  MS_LOG(DEBUG) << "Do UpdateModel for parameter " << param_name;
+  if (param_aggrs_.count(param_name) == 0) {
+    // The param_name could include some other parameters like momentum, but we don't consider that invalid. So here
+    // we just print a warning log and return true.
+    MS_LOG(WARNING) << "Parameter " << param_name << " is not registered in server.";
+    return true;
+  }
+
+  std::mutex &mtx = parameter_mutex_[param_name];
+  std::unique_lock<std::mutex> lock(mtx);
+  auto &param_aggr = param_aggrs_[param_name];
+
+  if (!param_aggr->UpdateData(upload_data)) {
+    MS_LOG(ERROR) << "Updating data for parameter " << param_name << " failed.";
+    return false;
+  }
+  // Different from Push, UpdateModel doesn't need to check the aggregation status.
+  if (!param_aggr->LaunchAggregators()) {
+    MS_LOG(ERROR) << "Launching aggregators for parameter " << param_name << " failed.";
+    return false;
+  }
+  return true;
+}
+
+bool Executor::HandleModelUpdateAsync(const std::map<std::string, UploadData> &feature_map) {
+  std::unique_lock<std::mutex> model_lock(model_mutex_);
+  for (const auto &trainable_param : feature_map) {
+    const std::string &param_name = trainable_param.first;
+    if (param_aggrs_.count(param_name) == 0) {
+      MS_LOG(WARNING) << "Parameter " << param_name << " is not registered in server.";
+      continue;
+    }
+
+    std::mutex &mtx = parameter_mutex_[param_name];
+    std::unique_lock<std::mutex> lock(mtx);
+    auto &param_aggr = param_aggrs_[param_name];
+
+    const UploadData &upload_data = trainable_param.second;
+    if (!param_aggr->UpdateData(upload_data)) {
+      MS_LOG(ERROR) << "Updating data for parameter " << param_name << " failed.";
+      return false;
+    }
+    if (!param_aggr->LaunchAggregators()) {
+      MS_LOG(ERROR) << "Launching aggregators for parameter " << param_name << " failed.";
+      return false;
+    }
+  }
+  return true;
+}
+
+bool Executor::HandleOverwriteWeightsByKey(const std::map<std::string, Address> &feature_map) {
+  for (const auto &trainable_param : feature_map) {
+    const std::string &param_name = trainable_param.first;
+    if (param_aggrs_.count(param_name) == 0) {
+      MS_LOG(WARNING) << "Weight " << param_name << " is not registered in server.";
+      continue;
+    }
+
+    std::mutex &mtx = parameter_mutex_[param_name];
+    std::unique_lock<std::mutex> lock(mtx);
+    auto &param_aggr = param_aggrs_[param_name];
+
+    AddressPtr old_weight = param_aggr->GetWeight();
+    if (old_weight == nullptr) {
+      MS_LOG(ERROR) << "Get weight of " << param_name << " failed: the AddressPtr is nullptr.";
+      return false;
+    }
+
+    const Address &new_weight = trainable_param.second;
+    if (new_weight.addr == nullptr) {
+      MS_LOG(ERROR) << "The new weight is nullptr.";
+      return false;
+    }
+
+    int ret = memcpy_s(old_weight->addr, old_weight->size, new_weight.addr, new_weight.size);
+    if (ret != 0) {
+      MS_LOG(ERROR) << "memcpy_s error, errno(" << ret << ")";
+      return false;
+    }
+  }
+  return true;
+}
+
+AddressPtr Executor::HandlePull(const std::string &param_name) {
+  MS_LOG(INFO) << "Handle blocking pull msg for parameter " << param_name;
+  if (param_aggrs_.count(param_name) == 0) {
+    MS_LOG(WARNING) << "Parameter " << param_name << " is not registered in server.";
+    return nullptr;
+  }
+
+  std::mutex &mtx = parameter_mutex_[param_name];
+  std::unique_lock<std::mutex> lock(mtx);
+  auto &param_aggr = param_aggrs_[param_name];
+
+  // Pulling must wait until the optimizing process is done.
+  while (!param_aggr->IsOptimizingDone()) {
+    lock.unlock();
+    std::this_thread::sleep_for(std::chrono::milliseconds(5));
+    lock.lock();
+  }
+  AddressPtr addr = param_aggr->Pull();
+  // If this Pull is the last one, reset pulling and optimizing status.
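+  // (IsPullingDone() flips to true once every expected Pull for this parameter has been served.)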
+  if (param_aggr->IsPullingDone()) {
+    param_aggr->ResetOptimizingStatus();
+  }
+  return addr;
+}
+
+std::map<std::string, AddressPtr> Executor::HandleAsyncGetModel() {
+  std::unique_lock<std::mutex> lock(model_mutex_);
+  return GetModel();
+}
+
+std::map<std::string, AddressPtr> Executor::HandleGetWeightsByKey(const std::vector<std::string> &param_names) {
+  std::map<std::string, AddressPtr> weights;
+  for (const auto &param_name : param_names) {
+    if (param_aggrs_.count(param_name) == 0) {
+      MS_LOG(ERROR) << "Parameter " << param_name << " is not registered in server.";
+      return weights;
+    }
+
+    std::mutex &mtx = parameter_mutex_[param_name];
+    std::unique_lock<std::mutex> lock(mtx);
+    const auto &param_aggr = param_aggrs_[param_name];
+
+    AddressPtr addr = param_aggr->GetWeight();
+    if (addr == nullptr) {
+      MS_LOG(ERROR) << "Get weight of " << param_name << " failed: the AddressPtr is nullptr.";
+      continue;
+    }
+    weights[param_name] = addr;
+  }
+  return weights;
+}
+
+bool Executor::IsAllWeightAggregationDone() { return IsWeightAggrDone(param_names_); }
+
+bool Executor::IsWeightAggrDone(const std::vector<std::string> &param_names) {
+  for (const auto &name : param_names) {
+    if (param_aggrs_.count(name) == 0) {
+      MS_LOG(ERROR) << "Weight " << name << " is invalid in server.";
+      return false;
+    }
+
+    std::mutex &mtx = parameter_mutex_[name];
+    std::unique_lock<std::mutex> lock(mtx);
+    if (!param_aggrs_[name]->IsAggregationDone()) {
+      MS_LOG(DEBUG) << "Update model for " << name << " is not done yet.";
+      return false;
+    }
+  }
+  return true;
+}
+
+void Executor::ResetAggregationStatus() {
+  for (const auto &param_name : param_names_) {
+    std::mutex &mtx = parameter_mutex_[param_name];
+    std::unique_lock<std::mutex> lock(mtx);
+    param_aggrs_[param_name]->ResetAggregationStatus();
+  }
+  return;
+}
+
+std::map<std::string, AddressPtr> Executor::GetModel() {
+  std::map<std::string, AddressPtr> model = {};
+  for (const auto &name : param_names_) {
+    std::mutex &mtx = parameter_mutex_[name];
+    std::unique_lock<std::mutex> lock(mtx);
+    AddressPtr addr = param_aggrs_[name]->GetWeight();
+    if (addr == nullptr) {
+      MS_LOG(WARNING) << "Get weight of " << name << " failed.";
+      continue;
+    }
+    model[name] = addr;
+  }
+  return model;
+}
+
+// bool Executor::Unmask() {
+//   auto model = GetModel();
+//   return mindarmour::CipherMgr::GetInstance().UnMask(model);
+// }
+
+const std::vector<std::string> &Executor::param_names() const { return param_names_; }
+
+std::string Executor::GetTrainableParamName(const CNodePtr &cnode) {
+  MS_EXCEPTION_IF_NULL(cnode);
+  std::string cnode_name = AnfAlgo::GetCNodeName(cnode);
+  if (kNameToIdxMap.count(cnode_name) == 0) {
+    return "";
+  }
+  const OptimParamNameToIndex &index_info = kNameToIdxMap.at(cnode_name);
+  size_t weight_idx = index_info.at("inputs").at(kWeight);
+  AnfNodePtr weight_node = AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(cnode, weight_idx), 0).first;
+  MS_EXCEPTION_IF_NULL(weight_node);
+  if (!weight_node->isa<Parameter>()) {
+    MS_LOG(EXCEPTION) << weight_idx << " input of " << cnode_name << " is not a Parameter.";
+  }
+  return weight_node->fullname_with_scope();
+}
+
+bool Executor::InitParamAggregator(const FuncGraphPtr &func_graph) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  const auto &cnodes = func_graph->GetOrderedCnodes();
+  for (const auto &cnode : cnodes) {
+    MS_EXCEPTION_IF_NULL(cnode);
+    const std::string &param_name = GetTrainableParamName(cnode);
+    if (param_name.empty()) {
+      continue;
+    }
+    if (param_aggrs_.count(param_name) != 0) {
+      MS_LOG(WARNING) << param_name << " already has its control flow.";
+      continue;
+    }
+
+    std::shared_ptr<ParameterAggregator> param_aggr = std::make_shared<ParameterAggregator>();
+    MS_EXCEPTION_IF_NULL(param_aggr);
+    param_names_.push_back(param_name);
+    param_aggrs_[param_name] = param_aggr;
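+    // Touching parameter_mutex_ with operator[] below default-constructs the mutex for this parameter.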
+    parameter_mutex_[param_name];
+    param_aggr->Init(cnode, aggregation_count_);
+    MS_LOG(DEBUG) << "Initializing control flow for param_name " << param_name << " success.";
+  }
+  return true;
+}
+}  // namespace server
+}  // namespace ps
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/ps/server/executor.h b/mindspore/ccsrc/ps/server/executor.h
new file mode 100644
index 00000000000..0befb3dd304
--- /dev/null
+++ b/mindspore/ccsrc/ps/server/executor.h
@@ -0,0 +1,125 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PS_SERVER_EXECUTOR_H_
+#define MINDSPORE_CCSRC_PS_SERVER_EXECUTOR_H_
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include "ps/server/common.h"
+#include "ps/server/parameter_aggregator.h"
+
+namespace mindspore {
+namespace ps {
+namespace server {
+// Executor is the entrance for the server to handle aggregation, optimizing, model querying, etc. It handles
+// the logic relevant to kernel launching.
+class Executor {
+ public:
+  static Executor &GetInstance() {
+    static Executor instance;
+    return instance;
+  }
+
+  // FuncGraphPtr func_graph is the graph compiled by the front end. aggregation_count is the number which will
+  // be used for aggregators.
+  // As noted in the header file parameter_aggregator.h, we create aggregators for trainable parameters, which are
+  // the optimizer cnode's inputs. So we need to initialize the server executor using func_graph.
+  void Init(const FuncGraphPtr &func_graph, size_t aggregation_count);
+
+  // Called in parameter server training mode to do the Push operation.
+  // For the same trainable parameter, HandlePush must be called aggregation_count_ times before it's considered
+  // as completed.
+  bool HandlePush(const std::string &param_name, const UploadData &upload_data);
+
+  // Called in parameter server training mode to do the Pull operation.
+  // Returns the value of parameter param_name.
+  // HandlePull must be called the same number of times as HandlePush before it's considered as
+  // completed.
+  AddressPtr HandlePull(const std::string &param_name);
+
+  // Called in federated learning training mode. Updates the value of parameter param_name.
+  bool HandleModelUpdate(const std::string &param_name, const UploadData &upload_data);
+
+  // Called in asynchronous federated learning training mode. Updates the current model with the new feature map
+  // asynchronously.
+  bool HandleModelUpdateAsync(const std::map<std::string, UploadData> &feature_map);
+
+  // Called in asynchronous federated learning training mode. Returns the whole model as key-value pairs where the
+  // key refers to the parameter name.
+  std::map<std::string, AddressPtr> HandleAsyncGetModel();
+
+  // Forcibly overwrite specific weights in the overwriteWeights message.
+  bool HandleOverwriteWeightsByKey(const std::map<std::string, Address> &feature_map);
+
+  // Returns the values of multiple trainable parameters passed by param_names.
+  std::map<std::string, AddressPtr> HandleGetWeightsByKey(const std::vector<std::string> &param_names);
+
+  // Reset the aggregation status for all aggregation kernels in the server.
+  void ResetAggregationStatus();
+
+  // Judge whether the aggregation processes for all weights/gradients are completed.
+  bool IsAllWeightAggregationDone();
+
+  // Judge whether the aggregation processes for the given param_names are completed.
+  bool IsWeightAggrDone(const std::vector<std::string> &param_names);
+
+  // Returns the whole model as key-value pairs where the key refers to the parameter name.
+  std::map<std::string, AddressPtr> GetModel();
+
+  // Returns whether the executor singleton is already initialized.
+  bool initialized() const;
+
+  const std::vector<std::string> &param_names() const;
+
+ private:
+  Executor() : initialized_(false), aggregation_count_(0) {}
+  ~Executor() = default;
+  Executor(const Executor &) = delete;
+  Executor &operator=(const Executor &) = delete;
+
+  // Returns the trainable parameter name parsed from this cnode.
+  std::string GetTrainableParamName(const CNodePtr &cnode);
+
+  // The server's graph is basically the same as the worker's graph, so we can get all the information we need from
+  // func_graph for later computations, including forward and backward propagation, aggregation, optimizing, etc.
+  bool InitParamAggregator(const FuncGraphPtr &func_graph);
+
+  bool initialized_;
+  size_t aggregation_count_;
+  std::vector<std::string> param_names_;
+
+  // The map of trainable parameter names to their ParameterAggregators, as noted in the header file
+  // parameter_aggregator.h
+  std::map<std::string, std::shared_ptr<ParameterAggregator>> param_aggrs_;
+
+  // This mutex ensures that operations on the whole model are threadsafe.
+  // The whole model is constructed by all trainable parameters.
+  std::mutex model_mutex_;
+
+  // Because ParameterAggregator is not threadsafe, we have to create a mutex for each ParameterAggregator so we can
+  // acquire the lock before calling its methods.
+  std::map<std::string, std::mutex> parameter_mutex_;
+};
+}  // namespace server
+}  // namespace ps
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_PS_SERVER_EXECUTOR_H_
diff --git a/mindspore/ccsrc/ps/server/iteration_timer.cc b/mindspore/ccsrc/ps/server/iteration_timer.cc
new file mode 100644
index 00000000000..4d4efab2d94
--- /dev/null
+++ b/mindspore/ccsrc/ps/server/iteration_timer.cc
@@ -0,0 +1,56 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ps/server/iteration_timer.h"
+
+namespace mindspore {
+namespace ps {
+namespace server {
+void IterationTimer::Start(const std::chrono::milliseconds &duration) {
+  if (running_.load()) {
+    MS_LOG(WARNING) << "The timer has already started.";
+    return;
+  }
+  running_ = true;
+  end_time_ = CURRENT_TIME_MILLI + duration;
+  monitor_thread_ = std::thread([&]() {
+    while (running_.load()) {
+      if (CURRENT_TIME_MILLI > end_time_) {
+        timeout_callback_();
+        running_ = false;
+      }
+      // The time tick is 1 millisecond.
+      std::this_thread::sleep_for(std::chrono::milliseconds(1));
+    }
+  });
+  monitor_thread_.detach();
+}
+
+void IterationTimer::Stop() { running_ = false; }
+
+void IterationTimer::SetTimeOutCallBack(const TimeOutCb &timeout_cb) {
+  timeout_callback_ = timeout_cb;
+  return;
+}
+
+bool IterationTimer::IsTimeOut(const std::chrono::milliseconds &timestamp) { return timestamp > end_time_; }
+
+bool IterationTimer::IsRunning() { return running_; }
+}  // namespace server
+}  // namespace ps
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/ps/server/iteration_timer.h b/mindspore/ccsrc/ps/server/iteration_timer.h
new file mode 100644
index 00000000000..8054e23a687
--- /dev/null
+++ b/mindspore/ccsrc/ps/server/iteration_timer.h
@@ -0,0 +1,64 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PS_SERVER_ITERATION_TIMER_H_
+#define MINDSPORE_CCSRC_PS_SERVER_ITERATION_TIMER_H_
+
+#include
+#include
+#include
+#include
+#include "ps/server/common.h"
+
+namespace mindspore {
+namespace ps {
+namespace server {
+// IterationTimer controls the time window for the purpose of eliminating trailing time in each iteration.
+class IterationTimer {
+ public:
+  IterationTimer() : running_(false), end_time_(0) {}
+  ~IterationTimer() = default;
+
+  // Start timing. The timer will stop after 'duration' milliseconds.
+  void Start(const std::chrono::milliseconds &duration);
+
+  // The caller could use this method to manually stop timing; otherwise the timer keeps timing until it expires.
+  void Stop();
+
+  // Set the callback which will be called when the timer expires.
+  void SetTimeOutCallBack(const TimeOutCb &timeout_cb);
+
+  // Judge whether the given timestamp is out of the time window's range since the Start method was called.
+  bool IsTimeOut(const std::chrono::milliseconds &timestamp);
+
+  // Judge whether the timer is still timing.
+  bool IsRunning();
+
+ private:
+  // The running state of the timer.
+  std::atomic<bool> running_;
+
+  // The timestamp in milliseconds at which the timer should stop timing.
+  std::chrono::milliseconds end_time_;
+
+  // The thread that keeps timing and calls timeout_callback_ when the timer expires.
+  std::thread monitor_thread_;
+  TimeOutCb timeout_callback_;
+};
+}  // namespace server
+}  // namespace ps
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_PS_SERVER_ITERATION_TIMER_H_
diff --git a/mindspore/ccsrc/ps/server/kernel/aggregation_kernel.h b/mindspore/ccsrc/ps/server/kernel/aggregation_kernel.h
new file mode 100644
index 00000000000..5c44afd200e
--- /dev/null
+++ b/mindspore/ccsrc/ps/server/kernel/aggregation_kernel.h
@@ -0,0 +1,95 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_AGGREGATION_KERNEL_H_
+#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_AGGREGATION_KERNEL_H_
+
+#include
+#include
+#include
+#include "backend/kernel_compiler/cpu/cpu_kernel.h"
+#include "ps/server/common.h"
+#include "ps/server/memory_register.h"
+#include "ps/server/kernel/params_info.h"
+
+namespace mindspore {
+namespace ps {
+namespace server {
+namespace kernel {
+// AggregationKernel is the kernel for aggregation of weights, gradients or other kinds of parameters.
+// For example: dense gradient accumulation, federated average, etc.
+// Normally the aggregation process in AggregationKernel is like a finite-state machine:
+// Initial -> Aggregating -> Aggregation done -> Initial.
+class AggregationKernel : public CPUKernel {
+ public:
+  AggregationKernel() : name_(""), done_(false), done_count_(0), accum_count_(0) {}
+  virtual ~AggregationKernel() = default;
+
+  // InitKernel and Launch override pure virtual methods of CPUKernel, so they must have implementations.
+  virtual void InitKernel(const CNodePtr &kernel_node) {}
+  virtual bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+                      const std::vector<AddressPtr> &outputs) {
+    return true;
+  }
+
+  // Server kernel's memory allocation method, which is different from the workflow in
+  // Session(GPUSession/CPUSession/AscendSession).
+  // virtual void AssignMemory(const CNodePtr &kernel_node, std::shared_ptr<MemoryRegister> memory_register) = 0;
+
+  // Set the cumulative count this aggregation kernel needs before aggregation is done.
+  void set_done_count(size_t count) { done_count_ = count; }
+
+  // Reset sets the finite-state machine back to the Initial state after this round of aggregation is done.
+  virtual void Reset() = 0;
+
+  virtual bool IsAggregationDone() = 0;
+
+  // Setter and getters of the kernel's parameter information.
+  void set_params_info(const ParamsInfo &params_info) { params_info_ = params_info; }
+  const std::vector<std::string> &input_names() { return params_info_.inputs_names(); }
+  const std::vector<std::string> &workspace_names() { return params_info_.workspace_names(); }
+  const std::vector<std::string> &output_names() { return params_info_.outputs_names(); }
+
+  // Returns information about whether some inputs should reuse the kernel node's input memory.
+  const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info() { return reuse_kernel_node_inputs_info_; }
+
+ protected:
+  virtual void GenerateReuseKernelNodeInfo() = 0;
+  // Aggregation kernel's name, which is set by the kernel register function.
+  std::string name_;
+
+  // Whether this round of aggregation is done.
+  bool done_;
+
+  // The cumulative count this aggregation kernel needs before aggregation is considered done.
+  size_t done_count_;
+
+  // Current cumulative count.
+  size_t accum_count_;
+
+  // Parameters information used for kernel registering, memory assignment, etc.
+  ParamsInfo params_info_;
+
+  // Information about server kernels reusing kernel node input memory from the front end.
+  // Key refers to the server kernel's input index. Value refers to the kernel node's input index.
+  ReuseKernelNodeInfo reuse_kernel_node_inputs_info_;
+};
+}  // namespace kernel
+}  // namespace server
+}  // namespace ps
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_PS_SERVER_KERNEL_AGGREGATION_KERNEL_H_
diff --git a/mindspore/ccsrc/ps/server/kernel/aggregation_kernel_factory.cc b/mindspore/ccsrc/ps/server/kernel/aggregation_kernel_factory.cc
new file mode 100644
index 00000000000..f2179e4252e
--- /dev/null
+++ b/mindspore/ccsrc/ps/server/kernel/aggregation_kernel_factory.cc
@@ -0,0 +1,71 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ps/server/kernel/aggregation_kernel_factory.h"
+#include
+
+namespace mindspore {
+namespace ps {
+namespace server {
+namespace kernel {
+bool AggregationKernelFactory::Matched(const ParamsInfo &params_info, const CNodePtr &kernel_node) {
+  std::string cnode_name = AnfAlgo::GetCNodeName(kernel_node);
+  if (kNameToIdxMap.count(cnode_name) == 0) {
+    MS_LOG(ERROR) << "Can't find index info for kernel " << cnode_name;
+    return false;
+  }
+
+  auto input_name_to_idx = kNameToIdxMap.at(cnode_name).at("inputs");
+  size_t input_num = params_info.inputs_num();
+  for (size_t i = 0; i < input_num; i++) {
+    auto one_input_name_type = params_info.inputs_name_type(i);
+    std::string name = one_input_name_type.first;
+    if (input_name_to_idx.count(name) == 0) {
+      MS_LOG(DEBUG) << cnode_name << " does not have input named " << name
+                    << ". This is a new input for this aggregation kernel.";
+      continue;
+    }
+    size_t input_idx = input_name_to_idx.at(name);
+    TypeId kernel_node_input_type = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_idx);
+    TypeId registered_input_type = one_input_name_type.second;
+    if (registered_input_type != kernel_node_input_type) {
+      return false;
+    }
+  }
+
+  auto output_name_to_idx = kNameToIdxMap.at(cnode_name).at("outputs");
+  size_t output_num = params_info.outputs_num();
+  for (size_t i = 0; i < output_num; i++) {
+    auto one_output_name_type = params_info.outputs_name_type(i);
+    std::string name = one_output_name_type.first;
+    if (output_name_to_idx.count(name) == 0) {
+      MS_LOG(DEBUG) << cnode_name << " does not have output named " << name
+                    << ". This is a new output for this aggregation kernel.";
+      continue;
+    }
+    size_t output_idx = output_name_to_idx.at(name);
+    TypeId kernel_node_output_type = AnfAlgo::GetOutputInferDataType(kernel_node, output_idx);
+    TypeId registered_output_type = one_output_name_type.second;
+    if (registered_output_type != kernel_node_output_type) {
+      return false;
+    }
+  }
+  return true;
+}
+}  // namespace kernel
+}  // namespace server
+}  // namespace ps
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/ps/server/kernel/aggregation_kernel_factory.h b/mindspore/ccsrc/ps/server/kernel/aggregation_kernel_factory.h
new file mode 100644
index 00000000000..87c0ead82f7
--- /dev/null
+++ b/mindspore/ccsrc/ps/server/kernel/aggregation_kernel_factory.h
@@ -0,0 +1,71 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_AGGREGATION_KERNEL_FACTORY_H_
+#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_AGGREGATION_KERNEL_FACTORY_H_
+
+#include
+#include
+#include
+#include "ps/server/kernel/kernel_factory.h"
+#include "ps/server/kernel/aggregation_kernel.h"
+
+namespace mindspore {
+namespace ps {
+namespace server {
+namespace kernel {
+using AggregationKernelCreator = std::function<std::shared_ptr<AggregationKernel>()>;
+class AggregationKernelFactory
+    : public KernelFactory<std::shared_ptr<AggregationKernel>, AggregationKernelCreator> {
+ public:
+  static AggregationKernelFactory &GetInstance() {
+    static AggregationKernelFactory instance;
+    return instance;
+  }
+
+ private:
+  AggregationKernelFactory() = default;
+  ~AggregationKernelFactory() override = default;
+  AggregationKernelFactory(const AggregationKernelFactory &) = delete;
+  AggregationKernelFactory &operator=(const AggregationKernelFactory &) = delete;
+
+  // Judge whether the server aggregation kernel can be created according to the registered ParamsInfo.
+  bool Matched(const ParamsInfo &params_info, const CNodePtr &kernel_node) override;
+};
+
+class AggregationKernelRegister {
+ public:
+  AggregationKernelRegister(const std::string &name, const ParamsInfo &params_info,
+                            AggregationKernelCreator &&creator) {
+    AggregationKernelFactory::GetInstance().Register(name, params_info, std::move(creator));
+  }
+};
+
+// Register an aggregation kernel with one template type T.
+#define REG_AGGREGATION_KERNEL(NAME, PARAMS_INFO, CLASS, T)                                                  \
+  static_assert(std::is_base_of<AggregationKernel, CLASS<T>>::value, " must be base of AggregationKernel"); \
+  static const AggregationKernelRegister g_##NAME##_##T##_aggregation_kernel_reg(                           \
+    #NAME, PARAMS_INFO, []() { return std::make_shared<CLASS<T>>(); });
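+
+// Usage sketch: dense_grad_accum_kernel.cc in this patch registers its kernel with
+//   REG_AGGREGATION_KERNEL(DenseGradAccum,
+//                          ParamsInfo().AddInputNameType(kGradient, kNumberTypeFloat32)
+//                                      .AddInputNameType(kNewGradient, kNumberTypeFloat32),
+//                          DenseGradAccumKernel, float)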
+
+// Register an aggregation kernel with two template types: T and S.
+#define REG_AGGREGATION_KERNEL_TWO(NAME, PARAMS_INFO, CLASS, T, S)                                              \
+  static_assert(std::is_base_of<AggregationKernel, CLASS<T, S>>::value, " must be base of AggregationKernel"); \
+  static const AggregationKernelRegister g_##NAME##_##T##_##S##_aggregation_kernel_reg(                        \
+    #NAME, PARAMS_INFO, []() { return std::make_shared<CLASS<T, S>>(); });
+}  // namespace kernel
+}  // namespace server
+}  // namespace ps
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_PS_SERVER_KERNEL_AGGREGATION_KERNEL_FACTORY_H_
diff --git a/mindspore/ccsrc/ps/server/kernel/apply_momentum_kernel.cc b/mindspore/ccsrc/ps/server/kernel/apply_momentum_kernel.cc
new file mode 100644
index 00000000000..218fd982bea
--- /dev/null
+++ b/mindspore/ccsrc/ps/server/kernel/apply_momentum_kernel.cc
@@ -0,0 +1,34 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ps/server/kernel/apply_momentum_kernel.h"
+
+namespace mindspore {
+namespace ps {
+namespace server {
+namespace kernel {
+REG_OPTIMIZER_KERNEL(ApplyMomentum,
+                     ParamsInfo()
+                       .AddInputNameType(kWeight, kNumberTypeFloat32)
+                       .AddInputNameType(kAccumulation, kNumberTypeFloat32)
+                       .AddInputNameType(kLearningRate, kNumberTypeFloat32)
+                       .AddInputNameType(kGradient, kNumberTypeFloat32)
+                       .AddInputNameType(kMomentum, kNumberTypeFloat32),
+                     ApplyMomentumKernel, float)
+}  // namespace kernel
+}  // namespace server
+}  // namespace ps
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/ps/server/kernel/apply_momentum_kernel.h b/mindspore/ccsrc/ps/server/kernel/apply_momentum_kernel.h
new file mode 100644
index 00000000000..28089a55698
--- /dev/null
+++ b/mindspore/ccsrc/ps/server/kernel/apply_momentum_kernel.h
@@ -0,0 +1,61 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_APPLY_MOMENTUM_KERNEL_H_ +#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_APPLY_MOMENTUM_KERNEL_H_ + +#include +#include +#include +#include "backend/kernel_compiler/cpu/apply_momentum_cpu_kernel.h" +#include "ps/server/kernel/optimizer_kernel.h" +#include "ps/server/kernel/optimizer_kernel_factory.h" + +namespace mindspore { +namespace ps { +namespace server { +namespace kernel { +using mindspore::kernel::ApplyMomentumCPUKernel; +template +class ApplyMomentumKernel : public ApplyMomentumCPUKernel, public OptimizerKernel { + public: + ApplyMomentumKernel() = default; + ~ApplyMomentumKernel() override = default; + + void InitKernel(const CNodePtr &cnode) override { + ApplyMomentumCPUKernel::InitKernel(cnode); + InitServerKernelInputOutputSize(cnode); + GenerateReuseKernelNodeInfo(); + } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override { + return ApplyMomentumCPUKernel::Launch(inputs, workspace, outputs); + } + + void GenerateReuseKernelNodeInfo() override { + reuse_kernel_node_inputs_info_.insert(std::make_pair(kWeight, 0)); + reuse_kernel_node_inputs_info_.insert(std::make_pair(kAccumulation, 1)); + reuse_kernel_node_inputs_info_.insert(std::make_pair(kLearningRate, 2)); + reuse_kernel_node_inputs_info_.insert(std::make_pair(kMomentum, 4)); + return; + } +}; +} // namespace kernel +} // namespace server +} // namespace ps +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_APPLY_MOMENTUM_KERNEL_H_ diff --git a/mindspore/ccsrc/ps/server/kernel/dense_grad_accum_kernel.cc b/mindspore/ccsrc/ps/server/kernel/dense_grad_accum_kernel.cc new file mode 100644 index 00000000000..61c51f24e6e --- /dev/null +++ b/mindspore/ccsrc/ps/server/kernel/dense_grad_accum_kernel.cc @@ -0,0 +1,30 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ps/server/kernel/dense_grad_accum_kernel.h" + +namespace mindspore { +namespace ps { +namespace server { +namespace kernel { +REG_AGGREGATION_KERNEL( + DenseGradAccum, + ParamsInfo().AddInputNameType(kGradient, kNumberTypeFloat32).AddInputNameType(kNewGradient, kNumberTypeFloat32), + DenseGradAccumKernel, float) +} // namespace kernel +} // namespace server +} // namespace ps +} // namespace mindspore diff --git a/mindspore/ccsrc/ps/server/kernel/dense_grad_accum_kernel.h b/mindspore/ccsrc/ps/server/kernel/dense_grad_accum_kernel.h new file mode 100644 index 00000000000..c8cd522d76c --- /dev/null +++ b/mindspore/ccsrc/ps/server/kernel/dense_grad_accum_kernel.h @@ -0,0 +1,95 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_DENSE_GRAD_ACCUM_KERNEL_H_
+#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_DENSE_GRAD_ACCUM_KERNEL_H_
+
+#include
+#include
+#include
+#include
+#include "backend/kernel_compiler/cpu/cpu_kernel.h"
+#include "ps/server/kernel/aggregation_kernel.h"
+#include "ps/server/kernel/aggregation_kernel_factory.h"
+
+namespace mindspore {
+namespace ps {
+namespace server {
+namespace kernel {
+template <typename T>
+class DenseGradAccumKernel : public AggregationKernel {
+ public:
+  DenseGradAccumKernel() = default;
+  ~DenseGradAccumKernel() override = default;
+
+  void InitKernel(const CNodePtr &kernel_node) override {
+    MS_EXCEPTION_IF_NULL(kernel_node);
+    std::string cnode_name = AnfAlgo::GetCNodeName(kernel_node);
+    if (kNameToIdxMap.count(cnode_name) == 0 || kNameToIdxMap.at(cnode_name).count("inputs") == 0 ||
+        kNameToIdxMap.at(cnode_name).at("inputs").count("grad") == 0) {
+      MS_LOG(EXCEPTION) << "Can't find index info of grad for kernel " << cnode_name;
+      return;
+    }
+    size_t cnode_grad_idx = kNameToIdxMap.at(cnode_name).at("inputs").at("grad");
+    std::vector<size_t> grad_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, cnode_grad_idx);
+    size_t grad_size = std::accumulate(grad_shape.begin(), grad_shape.end(), sizeof(T), std::multiplies<size_t>());
+    input_size_list_.push_back(grad_size);
+    size_t new_grad_size = grad_size;
+    input_size_list_.push_back(new_grad_size);
+    GenerateReuseKernelNodeInfo();
+    return;
+  }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs) override {
+    if (accum_count_ == 0) {
+      int ret = memset_s(inputs[0]->addr, inputs[0]->size, 0x00, inputs[0]->size);
+      if (ret != 0) {
+        MS_LOG(ERROR) << "memset_s error, errno(" << ret << ")";
+        return false;
+      }
+    }
+
+    T *grad_addr = reinterpret_cast<T *>(inputs[0]->addr);
+    T *new_grad_addr = reinterpret_cast<T *>(inputs[1]->addr);
+    for (size_t i = 0; i < inputs[0]->size / sizeof(T); i++) {
+      grad_addr[i] += new_grad_addr[i];
+    }
+
+    accum_count_++;
+    if (accum_count_ > done_count_) {
+      MS_LOG(ERROR) << "accum_count_ should not be greater than done_count_ " << done_count_;
+      return false;
+    }
+    if (accum_count_ == done_count_) {
+      for (size_t i = 0; i < inputs[0]->size / sizeof(T); i++) {
+        grad_addr[i] /= done_count_;
+      }
+    }
+    return true;
+  }
+
+  void Reset() override { accum_count_ = 0; }
+
+  bool IsAggregationDone() override { return accum_count_ >= done_count_; }
+
+  void GenerateReuseKernelNodeInfo() override { return; }
+};
+}  // namespace kernel
+}  // namespace server
+}  // namespace ps
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_PS_SERVER_KERNEL_DENSE_GRAD_ACCUM_KERNEL_H_
diff --git a/mindspore/ccsrc/ps/server/kernel/kernel_factory.h b/mindspore/ccsrc/ps/server/kernel/kernel_factory.h
new file mode 100644
index 00000000000..2490c7c95c9
--- /dev/null
+++ b/mindspore/ccsrc/ps/server/kernel/kernel_factory.h
@@ -0,0 +1,92 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_KERNEL_FACTORY_H_
+#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_KERNEL_FACTORY_H_
+
+#include
+#include
+#include
+#include
+#include
+#include "ps/server/common.h"
+#include "ps/server/kernel/params_info.h"
+
+namespace mindspore {
+namespace ps {
+namespace server {
+namespace kernel {
+// KernelFactory is used to select and build kernels in the server. It's the base class of OptimizerKernelFactory
+// and AggregationKernelFactory.
+
+// Unlike normal MindSpore operator kernels, the server defines multiple types of kernels. For example: Aggregation
+// Kernel, Optimizer Kernel, Forward Kernel, etc. So we define KernelFactory as a template class for registering all
+// types of kernels.
+
+// Because most of the information we need to create a server kernel is in the func_graph passed by the front end,
+// we create a server kernel based on a cnode.
+
+// Typename K refers to the shared_ptr of the kernel type.
+// Typename C refers to the creator function of the kernel.
+template <typename K, typename C>
+class KernelFactory {
+ public:
+  KernelFactory() = default;
+  virtual ~KernelFactory() = default;
+
+  static KernelFactory &GetInstance() {
+    static KernelFactory instance;
+    return instance;
+  }
+
+  // Kernels are registered by their parameter information and creator (constructor).
+  void Register(const std::string &name, const ParamsInfo &params_info, C &&creator) {
+    name_to_creator_map_[name].push_back(std::make_pair(params_info, creator));
+  }
+
+  // The kernels in the server are created from the func_graph's kernel_node passed by the front end.
+  K Create(const std::string &name, const CNodePtr &kernel_node) {
+    if (name_to_creator_map_.count(name) == 0) {
+      MS_LOG(ERROR) << "Creating kernel failed: " << name << " is not registered.";
+    }
+    for (const auto &name_type_creator : name_to_creator_map_[name]) {
+      const ParamsInfo &params_info = name_type_creator.first;
+      const C &creator = name_type_creator.second;
+      if (Matched(params_info, kernel_node)) {
+        auto kernel = creator();
+        kernel->set_params_info(params_info);
+        return kernel;
+      }
+    }
+    return nullptr;
+  }
+
+ private:
+  KernelFactory(const KernelFactory &) = delete;
+  KernelFactory &operator=(const KernelFactory &) = delete;
+
+  // Judge whether the server kernel can be created according to the registered ParamsInfo.
+  virtual bool Matched(const ParamsInfo &params_info, const CNodePtr &kernel_node) { return true; }
+
+  // Generally, a server kernel can correspond to several ParamsInfo entries, which are registered by the method
+  // 'Register' in the server kernel's *.cc files.
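+  // For example (illustrative shape of this map):
+  //   {"ApplyMomentum" -> {(ParamsInfo for float32 inputs, creator), ...}}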
diff --git a/mindspore/ccsrc/ps/server/kernel/optimizer_kernel.h b/mindspore/ccsrc/ps/server/kernel/optimizer_kernel.h
new file mode 100644
index 00000000000..7bc45ca6b5c
--- /dev/null
+++ b/mindspore/ccsrc/ps/server/kernel/optimizer_kernel.h
@@ -0,0 +1,97 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_OPTIMIZER_KERNEL_H_
+#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_OPTIMIZER_KERNEL_H_
+
+#include <functional>
+#include <numeric>
+#include <string>
+#include <vector>
+#include "backend/kernel_compiler/common_utils.h"
+#include "backend/kernel_compiler/cpu/cpu_kernel.h"
+#include "ps/server/common.h"
+#include "ps/server/memory_register.h"
+#include "ps/server/kernel/params_info.h"
+
+namespace mindspore {
+namespace ps {
+namespace server {
+namespace kernel {
+using mindspore::kernel::IsSameShape;
+using mindspore::kernel::USE_NESTEROV;
+
+// OptimizerKernel is the kernel in server for weights' optimizing.
+// Normally, server's optimizer kernels should inherit from CPU's optimizer kernels to reuse their implementation.
+class OptimizerKernel : public CPUKernel {
+ public:
+  OptimizerKernel() = default;
+  virtual ~OptimizerKernel() = default;
+
+  // InitKernel and Launch are pure virtual functions of CPUKernel, so they must have an implementation here.
+  virtual void InitKernel(const CNodePtr &kernel_node) {}
+  virtual bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+                      const std::vector<AddressPtr> &outputs) {
+    return true;
+  }
+
+  // Server kernel's memory allocation method, which is different from the workflow in
+  // Session(GPUSession/CPUSession/AscendSession).
+  // virtual void AssignMemory(const CNodePtr &kernel_node, std::shared_ptr<MemoryRegister> memory_register) = 0;
+
+  // Setter and getters for the kernel's parameter information.
+  void set_params_info(const ParamsInfo &params_info) { params_info_ = params_info; }
+  const std::vector<std::string> &input_names() { return params_info_.inputs_names(); }
+  const std::vector<std::string> &workspace_names() { return params_info_.workspace_names(); }
+  const std::vector<std::string> &output_names() { return params_info_.outputs_names(); }
+
+  // Returns information about whether some inputs should reuse the kernel node's input memory.
+  const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info() { return reuse_kernel_node_inputs_info_; }
+
+ protected:
+  virtual void GenerateReuseKernelNodeInfo() = 0;
+
+  void InitServerKernelInputOutputSize(const CNodePtr &kernel_node) {
+    MS_EXCEPTION_IF_NULL(kernel_node);
+    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    size_t type_size = sizeof(float);
+    for (size_t input_index = 0; input_index < input_num; ++input_index) {
+      std::vector<size_t> shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, input_index);
+      size_t tensor_size =
+        shape.empty() ? type_size : std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies<size_t>());
+      input_size_list_.emplace_back(tensor_size);
+    }
+    size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
+    for (size_t output_index = 0; output_index < output_num; ++output_index) {
+      std::vector<size_t> shape = AnfAlgo::GetOutputInferShape(kernel_node, output_index);
+      size_t tensor_size =
+        shape.empty() ? type_size : std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies<size_t>());
+      output_size_list_.emplace_back(tensor_size);
+    }
+  }
+
+  // Parameter information used for kernel registration, memory assignment, etc.
+  ParamsInfo params_info_;
+
+  // Information about the server kernel reusing kernel node input memory from the front end.
+  // Key refers to the server kernel's input index. Value refers to the kernel node's input index.
+  ReuseKernelNodeInfo reuse_kernel_node_inputs_info_;
+};
+}  // namespace kernel
+}  // namespace server
+}  // namespace ps
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_PS_SERVER_KERNEL_OPTIMIZER_KERNEL_H_
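InitServerKernelInputOutputSize folds the element size into std::accumulate's initial value, so the byte size of a tensor is type_size * dim0 * dim1 * ..., and an empty shape (a scalar) degenerates to type_size alone. A runnable sketch of just that computation:

#include <cstddef>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

// Byte size of a tensor: the element size seeds the product over all dimensions.
size_t TensorSizeInBytes(const std::vector<size_t> &shape, size_t type_size) {
  return shape.empty() ? type_size
                       : std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies<size_t>());
}

int main() {
  std::cout << TensorSizeInBytes({32, 10}, sizeof(float)) << "\n";  // 1280
  std::cout << TensorSizeInBytes({}, sizeof(float)) << "\n";        // 4 (scalar)
  return 0;
}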
diff --git a/mindspore/ccsrc/ps/server/kernel/optimizer_kernel_factory.cc b/mindspore/ccsrc/ps/server/kernel/optimizer_kernel_factory.cc
new file mode 100644
index 00000000000..98465490931
--- /dev/null
+++ b/mindspore/ccsrc/ps/server/kernel/optimizer_kernel_factory.cc
@@ -0,0 +1,70 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ps/server/kernel/optimizer_kernel_factory.h"
+#include <memory>
+
+namespace mindspore {
+namespace ps {
+namespace server {
+namespace kernel {
+bool OptimizerKernelFactory::Matched(const ParamsInfo &params_info, const CNodePtr &kernel_node) {
+  std::string cnode_name = AnfAlgo::GetCNodeName(kernel_node);
+  if (kNameToIdxMap.count(cnode_name) == 0) {
+    MS_LOG(ERROR) << "Can't find index info for kernel " << cnode_name;
+    return false;
+  }
+
+  auto input_name_to_idx = kNameToIdxMap.at(cnode_name).at("inputs");
+  size_t input_num = params_info.inputs_num();
+  for (size_t i = 0; i < input_num; i++) {
+    auto one_input_name_type = params_info.inputs_name_type(i);
+    std::string name = one_input_name_type.first;
+    if (input_name_to_idx.count(name) == 0) {
+      MS_LOG(EXCEPTION) << cnode_name << " does not have an input named " << name;
+      return false;
+    }
+    size_t input_idx = input_name_to_idx.at(name);
+    TypeId kernel_node_input_type = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_idx);
+    TypeId registered_input_type = one_input_name_type.second;
+    if (registered_input_type != kernel_node_input_type) {
+      return false;
+    }
+  }
+
+  auto output_name_to_idx = kNameToIdxMap.at(cnode_name).at("outputs");
+  size_t output_num = params_info.outputs_num();
+  for (size_t i = 0; i < output_num; i++) {
+    auto one_output_name_type = params_info.outputs_name_type(i);
+    std::string name = one_output_name_type.first;
+    if (output_name_to_idx.count(name) == 0) {
+      MS_LOG(EXCEPTION) << cnode_name << " does not have an output named " << name;
+      return false;
+    }
+    size_t output_idx = output_name_to_idx.at(name);
+    TypeId kernel_node_output_type = AnfAlgo::GetOutputInferDataType(kernel_node, output_idx);
+    TypeId registered_output_type = one_output_name_type.second;
+    if (registered_output_type != kernel_node_output_type) {
+      return false;
+    }
+  }
+
+  return true;
+}
+}  // namespace kernel
+}  // namespace server
+}  // namespace ps
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/ps/server/kernel/optimizer_kernel_factory.h b/mindspore/ccsrc/ps/server/kernel/optimizer_kernel_factory.h
new file mode 100644
index 00000000000..987a93a0a52
--- /dev/null
+++ b/mindspore/ccsrc/ps/server/kernel/optimizer_kernel_factory.h
@@ -0,0 +1,64 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_OPTIMIZER_KERNEL_FACTORY_H_
+#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_OPTIMIZER_KERNEL_FACTORY_H_
+
+#include <memory>
+#include <string>
+#include <utility>
+#include "ps/server/kernel/kernel_factory.h"
+#include "ps/server/kernel/optimizer_kernel.h"
+
+namespace mindspore {
+namespace ps {
+namespace server {
+namespace kernel {
+using OptimizerKernelCreator = std::function<std::shared_ptr<OptimizerKernel>()>;
+class OptimizerKernelFactory : public KernelFactory<std::shared_ptr<OptimizerKernel>, OptimizerKernelCreator> {
+ public:
+  static OptimizerKernelFactory &GetInstance() {
+    static OptimizerKernelFactory instance;
+    return instance;
+  }
+
+ private:
+  OptimizerKernelFactory() = default;
+  ~OptimizerKernelFactory() override = default;
+  OptimizerKernelFactory(const OptimizerKernelFactory &) = delete;
+  OptimizerKernelFactory &operator=(const OptimizerKernelFactory &) = delete;
+
+  // Judge whether the server optimizer kernel can be created according to the registered ParamsInfo.
+  bool Matched(const ParamsInfo &params_info, const CNodePtr &kernel_node) override;
+};
+
+class OptimizerKernelRegister {
+ public:
+  OptimizerKernelRegister(const std::string &name, const ParamsInfo &params_info, OptimizerKernelCreator &&creator) {
+    OptimizerKernelFactory::GetInstance().Register(name, params_info, std::move(creator));
+  }
+};
+
+// Register an optimizer kernel with one template type T.
+#define REG_OPTIMIZER_KERNEL(NAME, PARAMS_INFO, CLASS, T)                                                 \
+  static_assert(std::is_base_of<OptimizerKernel, CLASS<T>>::value, " must be base of OptimizerKernel");   \
+  static const OptimizerKernelRegister g_##NAME##_##T##_optimizer_kernel_reg(                             \
+    #NAME, PARAMS_INFO, []() { return std::make_shared<CLASS<T>>(); });
+}  // namespace kernel
+}  // namespace server
+}  // namespace ps
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_PS_SERVER_KERNEL_OPTIMIZER_KERNEL_FACTORY_H_
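Given the macro above, a registration in a kernel's .cc file would presumably look like the following; the kernel class, parameter names, and types here are illustrative, not quoted from the patch:

// Hypothetical registration of a float32 momentum optimizer kernel. The macro
// expands to a static OptimizerKernelRegister whose constructor runs at load
// time and registers the creator lambda with the factory singleton.
REG_OPTIMIZER_KERNEL(ApplyMomentum,
                     ParamsInfo()
                       .AddInputNameType("weight", kNumberTypeFloat32)
                       .AddInputNameType("accumulate", kNumberTypeFloat32)
                       .AddInputNameType("learning_rate", kNumberTypeFloat32)
                       .AddInputNameType("gradient", kNumberTypeFloat32)
                       .AddInputNameType("momentum", kNumberTypeFloat32),
                     ApplyMomentumKernel, float)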
diff --git a/mindspore/ccsrc/ps/server/kernel/params_info.cc b/mindspore/ccsrc/ps/server/kernel/params_info.cc
new file mode 100644
index 00000000000..e68b7f2a700
--- /dev/null
+++ b/mindspore/ccsrc/ps/server/kernel/params_info.cc
@@ -0,0 +1,68 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ps/server/kernel/params_info.h"
+#include "utils/log_adapter.h"
+
+namespace mindspore {
+namespace ps {
+namespace server {
+namespace kernel {
+ParamsInfo &ParamsInfo::AddInputNameType(const std::string &name, TypeId type) {
+  inputs_name_type_.push_back(std::make_pair(name, type));
+  inputs_names_.push_back(name);
+  return *this;
+}
+
+ParamsInfo &ParamsInfo::AddWorkspaceNameType(const std::string &name, TypeId type) {
+  workspaces_name_type_.push_back(std::make_pair(name, type));
+  workspace_names_.push_back(name);
+  return *this;
+}
+
+ParamsInfo &ParamsInfo::AddOutputNameType(const std::string &name, TypeId type) {
+  outputs_name_type_.push_back(std::make_pair(name, type));
+  outputs_names_.push_back(name);
+  return *this;
+}
+
+size_t ParamsInfo::inputs_num() const { return inputs_name_type_.size(); }
+
+size_t ParamsInfo::outputs_num() const { return outputs_name_type_.size(); }
+
+const std::pair<std::string, TypeId> &ParamsInfo::inputs_name_type(size_t index) const {
+  if (index >= inputs_name_type_.size()) {
+    MS_LOG(EXCEPTION) << "Index " << index << " is out of bounds of inputs_name_type_.";
+  }
+  return inputs_name_type_[index];
+}
+
+const std::pair<std::string, TypeId> &ParamsInfo::outputs_name_type(size_t index) const {
+  if (index >= outputs_name_type_.size()) {
+    MS_LOG(EXCEPTION) << "Index " << index << " is out of bounds of outputs_name_type_.";
+  }
+  return outputs_name_type_[index];
+}
+
+const std::vector<std::string> &ParamsInfo::inputs_names() const { return inputs_names_; }
+
+const std::vector<std::string> &ParamsInfo::workspace_names() const { return workspace_names_; }
+
+const std::vector<std::string> &ParamsInfo::outputs_names() const { return outputs_names_; }
+}  // namespace kernel
+}  // namespace server
+}  // namespace ps
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/ps/server/kernel/params_info.h b/mindspore/ccsrc/ps/server/kernel/params_info.h
new file mode 100644
index 00000000000..522a46cb41b
--- /dev/null
+++ b/mindspore/ccsrc/ps/server/kernel/params_info.h
@@ -0,0 +1,70 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_PARAMS_INFO_H_
+#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_PARAMS_INFO_H_
+
+#include <string>
+#include <utility>
+#include <vector>
+#include "ir/dtype/type_id.h"
+
+namespace mindspore {
+namespace ps {
+namespace server {
+namespace kernel {
+// ParamsInfo is used for server computation kernels' registration, e.g. ApplyMomentumKernel, FedAvgKernel, etc.
+// Registering a server kernel needs every input/workspace/output parameter's name and type.
+// For example:
+// ParamsInfo()
+//   .AddInputNameType("input1_name", kNumberTypeFloat32)
+//   .AddInputNameType("input2_name", kNumberTypeUInt64)
+//   .AddWorkspaceNameType("workspace1_name", kNumberTypeFloat32)
+//   .AddOutputNameType("output1_name", kNumberTypeUInt64)
+// This invocation describes a server kernel with the parameters below:
+// An input with name "input1_name" and type float32.
+// An input with name "input2_name" and type uint64.
+// A workspace with name "workspace1_name" and type float32.
+// An output with name "output1_name" and type uint64.
+class ParamsInfo {
+ public:
+  ParamsInfo() = default;
+  ~ParamsInfo() = default;
+
+  ParamsInfo &AddInputNameType(const std::string &name, TypeId type);
+  ParamsInfo &AddWorkspaceNameType(const std::string &name, TypeId type);
+  ParamsInfo &AddOutputNameType(const std::string &name, TypeId type);
+  size_t inputs_num() const;
+  size_t outputs_num() const;
+  const std::pair<std::string, TypeId> &inputs_name_type(size_t index) const;
+  const std::pair<std::string, TypeId> &outputs_name_type(size_t index) const;
+  const std::vector<std::string> &inputs_names() const;
+  const std::vector<std::string> &workspace_names() const;
+  const std::vector<std::string> &outputs_names() const;
+
+ private:
+  std::vector<std::pair<std::string, TypeId>> inputs_name_type_;
+  std::vector<std::pair<std::string, TypeId>> workspaces_name_type_;
+  std::vector<std::pair<std::string, TypeId>> outputs_name_type_;
+  std::vector<std::string> inputs_names_;
+  std::vector<std::string> workspace_names_;
+  std::vector<std::string> outputs_names_;
+};
+}  // namespace kernel
+}  // namespace server
+}  // namespace ps
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_PS_SERVER_KERNEL_PARAMS_INFO_H_
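Because every Add* method returns *this, a complete parameter description can be built in a single expression, which is what makes the inline registration style of REG_OPTIMIZER_KERNEL workable. An illustrative fragment (parameter names are hypothetical):

// Fluent construction: the whole description is one expression that can be
// passed directly to a registration macro or a factory call.
ParamsInfo momentum_params = ParamsInfo()
                               .AddInputNameType("weight", kNumberTypeFloat32)
                               .AddInputNameType("grad", kNumberTypeFloat32);
// momentum_params.inputs_num() == 2
// momentum_params.inputs_names() == {"weight", "grad"}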
diff --git a/mindspore/ccsrc/ps/server/local_meta_storage.cc b/mindspore/ccsrc/ps/server/local_meta_storage.cc
new file mode 100644
index 00000000000..72cc08d4f46
--- /dev/null
+++ b/mindspore/ccsrc/ps/server/local_meta_storage.cc
@@ -0,0 +1,46 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ps/server/local_meta_storage.h"
+#include <string>
+
+namespace mindspore {
+namespace ps {
+namespace server {
+void LocalMetaStorage::remove_value(const std::string &name) {
+  std::unique_lock<std::mutex> lock(mtx_);
+  if (key_to_meta_.count(name) != 0) {
+    key_to_meta_.erase(key_to_meta_.find(name));
+  }
+}
+
+bool LocalMetaStorage::has_value(const std::string &name) {
+  std::unique_lock<std::mutex> lock(mtx_);
+  return key_to_meta_.count(name) != 0;
+}
+
+void LocalMetaStorage::set_curr_iter_num(size_t num) {
+  std::unique_lock<std::mutex> lock(mtx_);
+  curr_iter_num_ = num;
+}
+
+const size_t LocalMetaStorage::curr_iter_num() {
+  std::unique_lock<std::mutex> lock(mtx_);
+  return curr_iter_num_;
+}
+}  // namespace server
+}  // namespace ps
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/ps/server/local_meta_storage.h b/mindspore/ccsrc/ps/server/local_meta_storage.h
new file mode 100644
index 00000000000..3a7467e6589
--- /dev/null
+++ b/mindspore/ccsrc/ps/server/local_meta_storage.h
@@ -0,0 +1,88 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PS_SERVER_LOCAL_META_STORAGE_H_
+#define MINDSPORE_CCSRC_PS_SERVER_LOCAL_META_STORAGE_H_
+
+#include <any>
+#include <mutex>
+#include <string>
+#include <unordered_map>
+#include "ps/server/common.h"
+
+namespace mindspore {
+namespace ps {
+namespace server {
+// LocalMetaStorage class is used for metadata storage of this server process.
+// For example, the current iteration number, time windows for round kernels, etc.
+// LocalMetaStorage is threadsafe.
+class LocalMetaStorage {
+ public:
+  static LocalMetaStorage &GetInstance() {
+    static LocalMetaStorage instance;
+    return instance;
+  }
+
+  template <typename T>
+  void put_value(const std::string &name, const T &value) {
+    std::unique_lock<std::mutex> lock(mtx_);
+    key_to_meta_[name] = value;
+  }
+
+  template <typename T>
+  const T &value(const std::string &name) {
+    std::unique_lock<std::mutex> lock(mtx_);
+    try {
+      // Cast to a reference type so the returned reference binds to the stored
+      // object rather than to a local copy.
+      const T &value = std::any_cast<const T &>(key_to_meta_[name]);
+      return value;
+    } catch (const std::exception &e) {
+      MS_LOG(EXCEPTION) << "Value of " << name << " is not set.";
+    }
+  }
+
+  // This method returns a reference so that the caller can change the value without calling put_value.
+  template <typename T>
+  T &mutable_value(const std::string &name) {
+    std::unique_lock<std::mutex> lock(mtx_);
+    try {
+      return std::any_cast<T &>(key_to_meta_[name]);
+    } catch (const std::exception &e) {
+      MS_LOG(EXCEPTION) << "Value of " << name << " is not set.";
+    }
+  }
+
+  void remove_value(const std::string &name);
+  bool has_value(const std::string &name);
+
+  void set_curr_iter_num(size_t num);
+  const size_t curr_iter_num();
+
+ private:
+  LocalMetaStorage() = default;
+  ~LocalMetaStorage() = default;
+  LocalMetaStorage(const LocalMetaStorage &) = delete;
+  LocalMetaStorage &operator=(const LocalMetaStorage &) = delete;
+
+  // key_to_meta_ stores metadata with key-value format.
+  std::unordered_map<std::string, std::any> key_to_meta_;
+  // This mutex makes sure that the operations on key_to_meta_ are threadsafe.
+  std::mutex mtx_;
+  size_t curr_iter_num_ = 0;
+};
+}  // namespace server
+}  // namespace ps
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_PS_SERVER_LOCAL_META_STORAGE_H_
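Typical use of the singleton looks like the sketch below (the key name is hypothetical). Note that mutable_value hands back a reference into the std::any slot, so in-place updates need no second put_value:

auto &storage = LocalMetaStorage::GetInstance();
storage.put_value("time_window", size_t(3000));          // store a value under a key
size_t window = storage.value<size_t>("time_window");    // read it back: 3000
storage.mutable_value<size_t>("time_window") += 500;     // update in place: now 3500
storage.remove_value("time_window");                     // drop the entry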
diff --git a/mindspore/ccsrc/ps/server/memory_register.cc b/mindspore/ccsrc/ps/server/memory_register.cc
new file mode 100644
index 00000000000..3ceb0f37a6e
--- /dev/null
+++ b/mindspore/ccsrc/ps/server/memory_register.cc
@@ -0,0 +1,34 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ps/server/memory_register.h"
+#include <utility>
+
+namespace mindspore {
+namespace ps {
+namespace server {
+void MemoryRegister::RegisterAddressPtr(const std::string &name, const AddressPtr &address) {
+  addresses_.try_emplace(name, address);
+}
+
+void MemoryRegister::StoreFloatArray(std::unique_ptr<float[]> *array) { float_arrays_.push_back(std::move(*array)); }
+
+void MemoryRegister::StoreInt32Array(std::unique_ptr<int[]> *array) { int32_arrays_.push_back(std::move(*array)); }
+
+void MemoryRegister::StoreUint64Array(std::unique_ptr<size_t[]> *array) {
+  uint64_arrays_.push_back(std::move(*array));
+}
+}  // namespace server
+}  // namespace ps
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/ps/server/memory_register.h b/mindspore/ccsrc/ps/server/memory_register.h
new file mode 100644
index 00000000000..161de5c7bac
--- /dev/null
+++ b/mindspore/ccsrc/ps/server/memory_register.h
@@ -0,0 +1,88 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PS_SERVER_MEMORY_REGISTER_H_
+#define MINDSPORE_CCSRC_PS_SERVER_MEMORY_REGISTER_H_
+
+#include <map>
+#include <memory>
+#include <string>
+#include <typeinfo>
+#include <utility>
+#include <vector>
+#include "ps/server/common.h"
+
+namespace mindspore {
+namespace ps {
+namespace server {
+// Memory allocated in server normally holds trainable parameters, hyperparameters, gradients, etc.
+// MemoryRegister registers the memory with key-value format where the key refers to the address's name("grad",
+// "weights", etc) and the value is an AddressPtr.
+class MemoryRegister {
+ public:
+  MemoryRegister() = default;
+  ~MemoryRegister() = default;
+
+  std::map<std::string, AddressPtr> &addresses() { return addresses_; }
+  void RegisterAddressPtr(const std::string &name, const AddressPtr &address);
+
+  // In some cases, memory is passed by a unique_ptr which is allocated by the caller. It needs to be stored as well
+  // to avoid its data being released.
+  template <typename T>
+  void RegisterArray(const std::string &name, std::unique_ptr<T[]> *array, size_t size) {
+    MS_EXCEPTION_IF_NULL(array);
+    void *data = array->get();
+    AddressPtr addr = std::make_shared<Address>();
+    addr->addr = data;
+    addr->size = size;
+
+    if (typeid(T) == typeid(int)) {
+      auto int_arr = CastUniquePtr<int, T>(array);
+      StoreInt32Array(&int_arr);
+    } else if (typeid(T) == typeid(float)) {
+      auto float_arr = CastUniquePtr<float, T>(array);
+      StoreFloatArray(&float_arr);
+    } else if (typeid(T) == typeid(size_t)) {
+      auto uint64_arr = CastUniquePtr<size_t, T>(array);
+      StoreUint64Array(&uint64_arr);
+    } else {
+      MS_LOG(ERROR) << "MemoryRegister does not support type " << typeid(T).name();
+      return;
+    }
+
+    RegisterAddressPtr(name, addr);
+    return;
+  }
+
+ private:
+  std::map<std::string, AddressPtr> addresses_;
+  std::vector<std::unique_ptr<float[]>> float_arrays_;
+  std::vector<std::unique_ptr<int[]>> int32_arrays_;
+  std::vector<std::unique_ptr<size_t[]>> uint64_arrays_;
+
+  void StoreInt32Array(std::unique_ptr<int[]> *array);
+  void StoreFloatArray(std::unique_ptr<float[]> *array);
+  void StoreUint64Array(std::unique_ptr<size_t[]> *array);
+
+  template <typename S, typename T>
+  std::unique_ptr<S[]> CastUniquePtr(std::unique_ptr<T[]> *array) {
+    return std::unique_ptr<S[]>{reinterpret_cast<S *>(array->release())};
+  }
+};
+}  // namespace server
+}  // namespace ps
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_PS_SERVER_MEMORY_REGISTER_H_
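A short illustrative fragment of RegisterArray's ownership handoff (buffer name and size are hypothetical): the register takes over the unique_ptr so the backing array outlives the caller's scope, while the AddressPtr exposes the raw pointer and byte size to kernels:

auto grad = std::make_unique<float[]>(1024);
MemoryRegister reg;
reg.RegisterArray("grad", &grad, 1024 * sizeof(float));  // grad is released; reg now owns the array
AddressPtr addr = reg.addresses()["grad"];               // addr->size == 4096, addr->addr still valid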
Source size: " << data.second.size; + int ret = memcpy_s(name_to_addr[name]->addr, name_to_addr[name]->size, data.second.addr, data.second.size); + if (ret != 0) { + MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; + return false; + } + } + return true; +} + +bool ParameterAggregator::LaunchAggregators() { + for (auto &aggregator_with_params : aggregation_kernel_parameters_) { + KernelParams ¶ms = aggregator_with_params.second; + std::shared_ptr aggr_kernel = aggregator_with_params.first; + RETURN_IF_NULL(aggr_kernel, false); + + bool ret = aggr_kernel->Launch(params.inputs, params.workspace, params.outputs); + if (!ret) { + MS_LOG(ERROR) << "Launching aggregation kernel " << typeid(aggr_kernel.get()).name() << " failed."; + continue; + } + } + return true; +} + +bool ParameterAggregator::LaunchOptimizers() { + for (auto &optimizer_with_params : optimizer_kernel_parameters_) { + KernelParams ¶ms = optimizer_with_params.second; + std::shared_ptr optimizer_kernel = optimizer_with_params.first; + RETURN_IF_NULL(optimizer_kernel, false); + + bool ret = optimizer_kernel->Launch(params.inputs, params.workspace, params.outputs); + if (!ret) { + MS_LOG(ERROR) << "Launching optimizer kernel " << typeid(optimizer_kernel.get()).name() << " failed."; + continue; + } + } + // As long as all the optimizer kernels are launched, consider optimizing for this ParameterAggregator as done. + optimizing_done_ = true; + return true; +} + +AddressPtr ParameterAggregator::Pull() { + if (memory_register_ == nullptr) { + MS_LOG(ERROR) + << "The memory register of ParameterAggregator is nullptr. Please initialize ParameterAggregator first."; + return nullptr; + } + + current_pull_count_++; + if (current_pull_count_ == required_pull_count_) { + pulling_done_ = true; + } + MS_LOG(DEBUG) << "The " << current_pull_count_ << " time of Pull. Pulling done status: " << pulling_done_; + + std::map &name_to_addr = memory_register_->addresses(); + return name_to_addr["weight"]; +} + +AddressPtr ParameterAggregator::GetWeight() { + if (memory_register_ == nullptr) { + MS_LOG(ERROR) + << "The memory register of ParameterAggregator is nullptr. Please initialize ParameterAggregator first."; + return nullptr; + } + std::map &name_to_addr = memory_register_->addresses(); + return name_to_addr["weight"]; +} + +void ParameterAggregator::ResetAggregationStatus() { + for (auto &aggregator_with_params : aggregation_kernel_parameters_) { + std::shared_ptr aggr_kernel = aggregator_with_params.first; + if (aggr_kernel == nullptr) { + MS_LOG(ERROR) << "The aggregation kernel is nullptr."; + continue; + } + aggr_kernel->Reset(); + } + return; +} + +void ParameterAggregator::ResetOptimizingStatus() { optimizing_done_ = false; } + +void ParameterAggregator::ResetPullingStatus() { + pulling_done_ = false; + current_pull_count_ = 0; +} + +bool ParameterAggregator::IsAggregationDone() const { + // Only consider aggregation done after each aggregation kernel is done. 
+
+bool ParameterAggregator::LaunchAggregators() {
+  for (auto &aggregator_with_params : aggregation_kernel_parameters_) {
+    KernelParams &params = aggregator_with_params.second;
+    std::shared_ptr<kernel::AggregationKernel> aggr_kernel = aggregator_with_params.first;
+    RETURN_IF_NULL(aggr_kernel, false);
+
+    bool ret = aggr_kernel->Launch(params.inputs, params.workspace, params.outputs);
+    if (!ret) {
+      MS_LOG(ERROR) << "Launching aggregation kernel " << typeid(aggr_kernel.get()).name() << " failed.";
+      continue;
+    }
+  }
+  return true;
+}
+
+bool ParameterAggregator::LaunchOptimizers() {
+  for (auto &optimizer_with_params : optimizer_kernel_parameters_) {
+    KernelParams &params = optimizer_with_params.second;
+    std::shared_ptr<kernel::OptimizerKernel> optimizer_kernel = optimizer_with_params.first;
+    RETURN_IF_NULL(optimizer_kernel, false);
+
+    bool ret = optimizer_kernel->Launch(params.inputs, params.workspace, params.outputs);
+    if (!ret) {
+      MS_LOG(ERROR) << "Launching optimizer kernel " << typeid(optimizer_kernel.get()).name() << " failed.";
+      continue;
+    }
+  }
+  // As long as all the optimizer kernels are launched, consider optimizing for this ParameterAggregator as done.
+  optimizing_done_ = true;
+  return true;
+}
+
+AddressPtr ParameterAggregator::Pull() {
+  if (memory_register_ == nullptr) {
+    MS_LOG(ERROR)
+      << "The memory register of ParameterAggregator is nullptr. Please initialize ParameterAggregator first.";
+    return nullptr;
+  }
+
+  current_pull_count_++;
+  if (current_pull_count_ == required_pull_count_) {
+    pulling_done_ = true;
+  }
+  MS_LOG(DEBUG) << "This is the " << current_pull_count_ << "th time of Pull. Pulling done status: " << pulling_done_;
+
+  std::map<std::string, AddressPtr> &name_to_addr = memory_register_->addresses();
+  return name_to_addr["weight"];
+}
+
+AddressPtr ParameterAggregator::GetWeight() {
+  if (memory_register_ == nullptr) {
+    MS_LOG(ERROR)
+      << "The memory register of ParameterAggregator is nullptr. Please initialize ParameterAggregator first.";
+    return nullptr;
+  }
+  std::map<std::string, AddressPtr> &name_to_addr = memory_register_->addresses();
+  return name_to_addr["weight"];
+}
+
+void ParameterAggregator::ResetAggregationStatus() {
+  for (auto &aggregator_with_params : aggregation_kernel_parameters_) {
+    std::shared_ptr<kernel::AggregationKernel> aggr_kernel = aggregator_with_params.first;
+    if (aggr_kernel == nullptr) {
+      MS_LOG(ERROR) << "The aggregation kernel is nullptr.";
+      continue;
+    }
+    aggr_kernel->Reset();
+  }
+  return;
+}
+
+void ParameterAggregator::ResetOptimizingStatus() { optimizing_done_ = false; }
+
+void ParameterAggregator::ResetPullingStatus() {
+  pulling_done_ = false;
+  current_pull_count_ = 0;
+}
+
+bool ParameterAggregator::IsAggregationDone() const {
+  // Only consider aggregation done after each aggregation kernel is done.
+  for (auto &aggregator_with_params : aggregation_kernel_parameters_) {
+    std::shared_ptr<kernel::AggregationKernel> aggr_kernel = aggregator_with_params.first;
+    RETURN_IF_NULL(aggr_kernel, false);
+    if (!aggr_kernel->IsAggregationDone()) {
+      return false;
+    }
+  }
+  return true;
+}
+
+bool ParameterAggregator::IsOptimizingDone() const { return optimizing_done_; }
+
+bool ParameterAggregator::IsPullingDone() const { return pulling_done_; }
+
+bool ParameterAggregator::InitAggregationKernels(const CNodePtr &cnode) {
+  MS_EXCEPTION_IF_NULL(cnode);
+  std::vector<std::string> aggr_kernel_names = SelectAggregationAlgorithm(cnode);
+  for (const std::string &name : aggr_kernel_names) {
+    auto aggr_kernel = kernel::AggregationKernelFactory::GetInstance().Create(name, cnode);
+    if (aggr_kernel == nullptr) {
+      MS_LOG(EXCEPTION) << "Failed to create aggregation kernel " << name << " for " << AnfAlgo::GetCNodeName(cnode);
+      return false;
+    }
+
+    // set_done_count must be called before InitKernel because InitKernel may use this count.
+    aggr_kernel->set_done_count(required_push_count_);
+    aggr_kernel->InitKernel(cnode);
+
+    const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info = aggr_kernel->reuse_kernel_node_inputs_info();
+    if (!AssignMemory(aggr_kernel, cnode, reuse_kernel_node_inputs_info, memory_register_)) {
+      MS_LOG(EXCEPTION) << "Assigning memory for kernel " << name << " failed.";
+      return false;
+    }
+
+    if (!GenerateAggregationKernelParams(aggr_kernel, memory_register_)) {
+      MS_LOG(EXCEPTION) << "Generating aggregation kernel parameters for " << name << " failed.";
+      return false;
+    }
+  }
+  return true;
+}
+
+bool ParameterAggregator::InitOptimizerKernels(const CNodePtr &cnode) {
+  // if (PSContext::instance()->server_mode() == kServerModeFL) {
+  //   MS_LOG(DEBUG) << "Federated learning mode doesn't need optimizer kernel.";
+  //   return false;
+  // }
+  MS_EXCEPTION_IF_NULL(cnode);
+  const std::string &name = AnfAlgo::GetCNodeName(cnode);
+  auto optimizer_kernel = kernel::OptimizerKernelFactory::GetInstance().Create(name, cnode);
+  if (optimizer_kernel == nullptr) {
+    MS_LOG(EXCEPTION) << "Failed to create optimizer kernel for " << name;
+    return false;
+  }
+
+  optimizer_kernel->InitKernel(cnode);
+
+  const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info = optimizer_kernel->reuse_kernel_node_inputs_info();
+  if (!AssignMemory(optimizer_kernel, cnode, reuse_kernel_node_inputs_info, memory_register_)) {
+    MS_LOG(EXCEPTION) << "Assigning memory for kernel " << name << " failed.";
+    return false;
+  }
+
+  if (!GenerateOptimizerKernelParams(optimizer_kernel, memory_register_)) {
+    MS_LOG(ERROR) << "Generating optimizer kernel parameters failed.";
+    return false;
+  }
+  return true;
+}
+
+template <typename K>
+bool ParameterAggregator::AssignMemory(K server_kernel, const CNodePtr &cnode,
+                                       const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info,
+                                       std::shared_ptr<MemoryRegister> memory_register) {
+  MS_EXCEPTION_IF_NULL(server_kernel);
+  MS_EXCEPTION_IF_NULL(cnode);
+
+  const std::vector<std::string> &input_names = server_kernel->input_names();
+  const std::vector<size_t> &input_size_list = server_kernel->GetInputSizeList();
+  if (input_names.size() != input_size_list.size()) {
+    MS_LOG(EXCEPTION) << "Server kernel " << typeid(server_kernel.get()).name()
+                      << " input number is not matched: input_names size is " << input_names.size()
+                      << ", input_size_list size is " << input_size_list.size();
+    return false;
+  }
+
+  if (reuse_kernel_node_inputs_info.size() > input_names.size()) {
+    MS_LOG(EXCEPTION) << "The reuse kernel node information number is invalid: got "
+                      << reuse_kernel_node_inputs_info.size() << ", but input_names size is " << input_names.size();
+    return false;
+  }
+
+  for (size_t i = 0; i < input_names.size(); i++) {
+    const std::string &name = input_names[i];
+    if (memory_register->addresses().count(name) != 0) {
+      MS_LOG(DEBUG) << "The memory for " << name << " is already assigned.";
+      continue;
+    }
+    if (reuse_kernel_node_inputs_info.count(name) != 0) {
+      // Reusing memory of the kernel node means the memory of the input is already assigned by the front end, which
+      // is to say, the input node is a parameter node.
+      size_t index = reuse_kernel_node_inputs_info.at(name);
+      MS_LOG(INFO) << "Try to reuse memory of kernel node " << AnfAlgo::GetCNodeName(cnode) << " for parameter "
+                   << name << ", kernel node index " << index;
+      AddressPtr input_addr = GenerateParameterNodeAddrPtr(cnode, index);
+      MS_EXCEPTION_IF_NULL(input_addr);
+      memory_register->RegisterAddressPtr(name, input_addr);
+    } else {
+      MS_LOG(INFO) << "Assign new memory for " << name;
+      auto input_addr = std::make_unique<float[]>(input_size_list[i]);
+      MS_EXCEPTION_IF_NULL(input_addr);
+      memory_register->RegisterArray(name, &input_addr, input_size_list[i]);
+    }
+  }
+  return true;
+}
+
+bool ParameterAggregator::GenerateAggregationKernelParams(const std::shared_ptr<kernel::AggregationKernel> aggr_kernel,
+                                                          const std::shared_ptr<MemoryRegister> memory_register) {
+  RETURN_IF_NULL(aggr_kernel, false);
+  RETURN_IF_NULL(memory_register, false);
+  KernelParams aggr_params = {};
+
+  const std::vector<std::string> &input_names = aggr_kernel->input_names();
+  std::transform(input_names.begin(), input_names.end(), std::back_inserter(aggr_params.inputs),
+                 [&](const std::string &name) { return memory_register->addresses()[name]; });
+
+  const std::vector<std::string> &workspace_names = aggr_kernel->workspace_names();
+  std::transform(workspace_names.begin(), workspace_names.end(), std::back_inserter(aggr_params.workspace),
+                 [&](const std::string &name) { return memory_register->addresses()[name]; });
+
+  const std::vector<std::string> &output_names = aggr_kernel->output_names();
+  std::transform(output_names.begin(), output_names.end(), std::back_inserter(aggr_params.outputs),
+                 [&](const std::string &name) { return memory_register->addresses()[name]; });
+
+  aggregation_kernel_parameters_.push_back(std::make_pair(aggr_kernel, aggr_params));
+  return true;
+}
+
+bool ParameterAggregator::GenerateOptimizerKernelParams(const std::shared_ptr<kernel::OptimizerKernel> optimizer_kernel,
+                                                        const std::shared_ptr<MemoryRegister> memory_register) {
+  RETURN_IF_NULL(optimizer_kernel, false);
+  RETURN_IF_NULL(memory_register, false);
+  KernelParams optimizer_params = {};
+
+  const std::vector<std::string> &input_names = optimizer_kernel->input_names();
+  std::transform(input_names.begin(), input_names.end(), std::back_inserter(optimizer_params.inputs),
+                 [&](const std::string &name) { return memory_register->addresses()[name]; });
+
+  const std::vector<std::string> &workspace_names = optimizer_kernel->workspace_names();
+  std::transform(workspace_names.begin(), workspace_names.end(), std::back_inserter(optimizer_params.workspace),
+                 [&](const std::string &name) { return memory_register->addresses()[name]; });
+
+  const std::vector<std::string> &output_names = optimizer_kernel->output_names();
+  std::transform(output_names.begin(), output_names.end(), std::back_inserter(optimizer_params.outputs),
+                 [&](const std::string &name) { return memory_register->addresses()[name]; });
+
+  optimizer_kernel_parameters_.push_back(std::make_pair(optimizer_kernel, optimizer_params));
+  return true;
+}
+
+std::vector<std::string> ParameterAggregator::SelectAggregationAlgorithm(const CNodePtr &cnode) {
+  std::vector<std::string> aggregation_algorithm = {};
+  MS_LOG(INFO) << "Aggregation algorithm selection result: " << aggregation_algorithm;
+  return aggregation_algorithm;
+}
+
+template bool ParameterAggregator::AssignMemory(std::shared_ptr<kernel::OptimizerKernel> server_kernel,
+                                                const CNodePtr &cnode,
+                                                const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info,
+                                                std::shared_ptr<MemoryRegister> memory_register);
+
+template bool ParameterAggregator::AssignMemory(std::shared_ptr<kernel::AggregationKernel> server_kernel,
+                                                const CNodePtr &cnode,
+                                                const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info,
+                                                std::shared_ptr<MemoryRegister> memory_register);
+}  // namespace server
+}  // namespace ps
+}  // namespace mindspore
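The header that follows documents the per-iteration finite-state machine this implementation drives. A self-contained model of the transitions in parameter server mode (state names are illustrative):

#include <cassert>

// Per-iteration life cycle of a ParameterAggregator in parameter server mode.
// The caller drives the transitions and must call the Reset* methods before
// the next iteration begins.
enum class State { kInitial, kAggregating, kAggregationDone, kOptimizing, kOptimizingDone, kPulling, kPullDone };

int main() {
  State s = State::kInitial;
  s = State::kAggregating;      // workers push gradients; LaunchAggregators runs per push
  s = State::kAggregationDone;  // IsAggregationDone() turns true at required_push_count_
  s = State::kOptimizing;       // LaunchOptimizers applies the optimizer kernels once
  s = State::kOptimizingDone;   // IsOptimizingDone() is now true
  s = State::kPulling;          // each Pull() increases current_pull_count_
  s = State::kPullDone;         // IsPullingDone() turns true at required_pull_count_
  s = State::kInitial;          // ResetAggregationStatus/ResetOptimizingStatus/ResetPullingStatus
  assert(s == State::kInitial);
  return 0;
}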
diff --git a/mindspore/ccsrc/ps/server/parameter_aggregator.h b/mindspore/ccsrc/ps/server/parameter_aggregator.h
new file mode 100644
index 00000000000..8344f75f123
--- /dev/null
+++ b/mindspore/ccsrc/ps/server/parameter_aggregator.h
@@ -0,0 +1,139 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PS_SERVER_PARAMETER_AGGREGATOR_H_
+#define MINDSPORE_CCSRC_PS_SERVER_PARAMETER_AGGREGATOR_H_
+
+#include <map>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+#include "ps/server/common.h"
+#include "ps/server/memory_register.h"
+#include "ps/server/kernel/aggregation_kernel_factory.h"
+#include "ps/server/kernel/optimizer_kernel_factory.h"
+
+namespace mindspore {
+namespace ps {
+namespace server {
+// Encapsulate the parameters for a kernel into a struct to make it convenient for ParameterAggregator to launch
+// server kernels.
+typedef struct {
+  std::vector<AddressPtr> inputs;
+  std::vector<AddressPtr> workspace;
+  std::vector<AddressPtr> outputs;
+} KernelParams;
+
+// ParameterAggregator includes methods for aggregating gradients and optimizing weights(launching aggregation and
+// optimizer kernels), getting weights, etc. It's not thread-safe, which means the caller must acquire a lock before
+// calling its methods concurrently.
+
+// Each ParameterAggregator corresponds to one weight for now.
+
+// ParameterAggregator is stateful because the process of aggregation and optimizing could be stateful.
+// For example, the finite-state machine for the ParameterAggregator in parameter server training mode is below:
+// Initial->Aggregating->Aggregation done->Optimizing->Optimizing done->Pulling->Pull done->Initial.
+class ParameterAggregator {
+ public:
+  ParameterAggregator()
+      : server_mode_(ServerMode::PARAMETER_SERVER),
+        required_push_count_(0),
+        required_pull_count_(0),
+        current_pull_count_(0),
+        aggregation_done_(false),
+        optimizing_done_(false),
+        pulling_done_(true),
+        memory_register_(nullptr) {}
+  ~ParameterAggregator() = default;
+
+  // Initialize ParameterAggregator with a cnode. For now this cnode is normally an optimizer kernel.
+  // The parameter required_count helps ParameterAggregator to judge its current status because it's stateful.
+  bool Init(const CNodePtr &cnode, size_t required_count = 0);
+
+  // Update the old data stored in ParameterAggregator with new data.
+  // The data could have many meanings: weights, gradients, learning_rate, momentum, etc.
+  bool UpdateData(const std::map<std::string, Address> &new_data);
+
+  // Launch the aggregators/optimizers of this ParameterAggregator in order.
+  bool LaunchAggregators();
+  bool LaunchOptimizers();
+
+  // The implementation for primitive Pull in parameter server training mode.
+  // Every call of this method increases the pull count by 1.
+  AddressPtr Pull();
+
+  // Different from the method Pull, this method simply returns the weight of this ParameterAggregator without
+  // changing any status.
+  AddressPtr GetWeight();
+
+  // After the aggregation/optimizing/pulling of one iteration is done, the caller must reset the status to ensure
+  // the correctness of the aggregation/optimizing/pulling for the next iteration.
+  void ResetAggregationStatus();
+  void ResetOptimizingStatus();
+  void ResetPullingStatus();
+
+  // Return the aggregation/optimizing/pulling status to the caller.
+  bool IsAggregationDone() const;
+  bool IsOptimizingDone() const;
+  bool IsPullingDone() const;
+
+ private:
+  // Initialize aggregation/optimizer kernels based on the cnode. The reason for this is described in the file
+  // kernel/kernel_factory.h.
+  bool InitAggregationKernels(const CNodePtr &cnode);
+  bool InitOptimizerKernels(const CNodePtr &cnode);
+
+  // Assign memory for server kernel K(AggregationKernel/OptimizerKernel).
+  // The memory assigned can be accessed by MemoryRegister. The memory could be weights, gradients, learning_rate,
+  // momentum, etc.
+  template <typename K>
+  bool AssignMemory(K server_kernel, const CNodePtr &cnode, const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info,
+                    std::shared_ptr<MemoryRegister> memory_register);
+
+  // Generate kernel parameters for aggregation/optimizer kernels. All the parameters are registered and stored in
+  // memory_register.
+  bool GenerateAggregationKernelParams(const std::shared_ptr<kernel::AggregationKernel> aggr_kernel,
+                                       const std::shared_ptr<MemoryRegister> memory_register);
+  bool GenerateOptimizerKernelParams(const std::shared_ptr<kernel::OptimizerKernel> optim_kernel,
+                                     const std::shared_ptr<MemoryRegister> memory_register);
+
+  // The selection of the aggregation algorithm depends on multiple factors. For example, server mode, user
+  // configuration, etc.
+  std::vector<std::string> SelectAggregationAlgorithm(const CNodePtr &cnode);
+
+  ServerMode server_mode_;
+  size_t required_push_count_;
+  size_t required_pull_count_;
+  size_t current_pull_count_;
+
+  // The status of aggregation/optimizing/pulling.
+  bool aggregation_done_;
+  bool optimizing_done_;
+  bool pulling_done_;
+
+  // ParameterAggregator stores all the data it needs for aggregation, optimizing, etc.
+  std::shared_ptr<MemoryRegister> memory_register_;
+
+  // An update could involve multiple aggregation and optimizer server kernels.
+  // Here we store pairs of server kernels and the parameters of their Launch functions.
+  std::vector<std::pair<std::shared_ptr<kernel::AggregationKernel>, KernelParams>> aggregation_kernel_parameters_;
+  std::vector<std::pair<std::shared_ptr<kernel::OptimizerKernel>, KernelParams>> optimizer_kernel_parameters_;
+};
+}  // namespace server
+}  // namespace ps
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_PS_SERVER_PARAMETER_AGGREGATOR_H_
diff --git a/tests/ut/cpp/CMakeLists.txt b/tests/ut/cpp/CMakeLists.txt
index 245ba4c264d..55899f4b9f7 100644
--- a/tests/ut/cpp/CMakeLists.txt
+++ b/tests/ut/cpp/CMakeLists.txt
@@ -168,6 +168,7 @@ list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/parameter_serve
 list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/ps_cache/gpu/gpu_ps_cache.cc")
 list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/ps_cache/ascend/ascend_ps_cache.cc")
 list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc")
+list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/server/kernel/apply_momentum_kernel.cc")
 list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_fusion.cc")
 list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/backend/optimizer/gpu/post_batch_norm_add_relu_fusion.cc")
 list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_grad_fusion.cc")