forked from mindspore-Ecosystem/mindspore
!15916 Add federated learning server part1
From: @zpac Reviewed-by: Signed-off-by:
This commit is contained in:
commit
9416502e90
|
@ -39,6 +39,20 @@ if(NOT ENABLE_GPU)
|
|||
list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/gpu/gpu_ps_cache.cc")
|
||||
endif()
|
||||
|
||||
if(WIN32 OR NOT ENABLE_CPU)
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/apply_momentum_kernel.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/aggregation_kernel_factory.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/dense_grad_accum_kernel.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/optimizer_kernel_factory.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/params_info.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "server/consistent_hash_ring.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "server/iteration_timer.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "server/local_meta_storage.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "server/memory_register.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "server/parameter_aggregator.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "server/executor.cc")
|
||||
endif()
|
||||
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_data/ps_data_prefetch.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_data/ps_data_channel.cc")
|
||||
add_subdirectory(ps_cache)
|
||||
|
|
|
@ -0,0 +1,189 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PS_SERVER_COMMON_H_
|
||||
#define MINDSPORE_CCSRC_PS_SERVER_COMMON_H_
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <numeric>
|
||||
#include <climits>
|
||||
#include <memory>
|
||||
#include <functional>
|
||||
#include "proto/ps.pb.h"
|
||||
#include "ir/anf.h"
|
||||
#include "utils/utils.h"
|
||||
#include "ir/dtype/type_id.h"
|
||||
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
|
||||
#include "ps/ps_context.h"
|
||||
#include "ps/core/communicator/http_message_handler.h"
|
||||
#include "ps/core/communicator/tcp_server.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace server {
|
||||
// Definitions for the server framework.
|
||||
enum ServerMode { PARAMETER_SERVER = 0, FL_SERVER };
|
||||
enum CommType { HTTP = 0, TCP };
|
||||
enum AggregationType { FedAvg = 0, FedAdam, FedAdagarg, FedMeta, qffl, DenseGradAccum, SparseGradAccum };
|
||||
|
||||
using kernel::Address;
|
||||
using kernel::AddressPtr;
|
||||
using kernel::CPUKernel;
|
||||
using TimeOutCb = std::function<void(void)>;
|
||||
using StopTimerCb = std::function<void(void)>;
|
||||
using FinishIterCb = std::function<void(void)>;
|
||||
using FinalizeCb = std::function<void(void)>;
|
||||
|
||||
// Information about whether server kernel will reuse kernel node memory from the front end.
|
||||
// Key refers to the server kernel's parameter name, like "weights", "grad", "learning_rate".
|
||||
// Value refers to the kernel node's parameter index.
|
||||
using ReuseKernelNodeInfo = std::map<std::string, size_t>;
|
||||
|
||||
// UploadData refers to the data which is uploaded by workers.
|
||||
// Key refers to the data name. For example: "weights", "grad", "learning_rate", etc. This will be set by the worker.
|
||||
// Value refers to the data of the key.
|
||||
|
||||
// We use Address instead of AddressPtr because:
|
||||
// 1. Address doesn't need to call make_shared<T> so it has better performance.
|
||||
// 2. The data uploaded by worker is normally parsed from FlatterBuffers or ProtoBuffer. For example: learning rate, new
|
||||
// weights, etc. Address is enough to store these data.
|
||||
|
||||
// Pay attention that Address only stores the void* pointer of the data, so the data must not be released before the
|
||||
// related logic is done.
|
||||
using UploadData = std::map<std::string, Address>;
|
||||
|
||||
constexpr auto kWeight = "weight";
|
||||
constexpr auto kAccumulation = "accum";
|
||||
constexpr auto kLearningRate = "lr";
|
||||
constexpr auto kGradient = "grad";
|
||||
constexpr auto kNewGradient = "new_grad";
|
||||
constexpr auto kMomentum = "momentum";
|
||||
constexpr auto kIndices = "indices";
|
||||
constexpr auto kAdamM = "m";
|
||||
constexpr auto kAdamV = "v";
|
||||
constexpr auto kAdamBeta1Power = "beta1_power";
|
||||
constexpr auto kAdamBeta2Power = "beta2_power";
|
||||
constexpr auto kAdamBeta1 = "beta1";
|
||||
constexpr auto kAdamBeta2 = "beta2";
|
||||
constexpr auto kAdamEps = "eps";
|
||||
constexpr auto kFtrlLinear = "linear";
|
||||
|
||||
// OptimParamNameToIndex represents every inputs/workspace/outputs parameter's offset when an optimizer kernel is
|
||||
// launched.
|
||||
using OptimParamNameToIndex = std::map<std::string, std::map<std::string, size_t>>;
|
||||
const OptimParamNameToIndex kMomentumNameToIdx = {
|
||||
{"inputs", {{kWeight, 0}, {kAccumulation, 1}, {kLearningRate, 2}, {kGradient, 3}, {kMomentum, 4}}}, {"outputs", {}}};
|
||||
const OptimParamNameToIndex kAdamNameToIdx = {{"inputs",
|
||||
{{kWeight, 0},
|
||||
{kAdamM, 1},
|
||||
{kAdamV, 2},
|
||||
{kAdamBeta1Power, 3},
|
||||
{kAdamBeta2Power, 4},
|
||||
{kLearningRate, 5},
|
||||
{kAdamBeta1, 6},
|
||||
{kAdamBeta2, 7},
|
||||
{kAdamEps, 8},
|
||||
{kGradient, 9}}},
|
||||
{"outputs", {}}};
|
||||
const OptimParamNameToIndex kSparseAdamNameToIdx = {{"inputs",
|
||||
{{kWeight, 0},
|
||||
{kAdamM, 1},
|
||||
{kAdamV, 2},
|
||||
{kAdamBeta1Power, 3},
|
||||
{kAdamBeta2Power, 4},
|
||||
{kLearningRate, 5},
|
||||
{kAdamBeta1, 6},
|
||||
{kAdamBeta1, 7},
|
||||
{kAdamEps, 8},
|
||||
{kGradient, 9},
|
||||
{kIndices, 10}}},
|
||||
{"outputs", {}}};
|
||||
const OptimParamNameToIndex kSparseFtrlNameToIdx = {
|
||||
{"inputs", {{kWeight, 0}, {kAccumulation, 1}, {kFtrlLinear, 2}, {kGradient, 3}, {kIndices, 4}}}, {"outputs", {}}};
|
||||
const std::map<std::string, OptimParamNameToIndex> kNameToIdxMap = {
|
||||
{kApplyMomentumOpName, kMomentumNameToIdx},
|
||||
{kFusedSparseAdamName, kSparseAdamNameToIdx},
|
||||
{kSparseApplyFtrlOpName, kSparseFtrlNameToIdx},
|
||||
{kApplyAdamOpName, kAdamNameToIdx},
|
||||
};
|
||||
|
||||
constexpr uint32_t kLeaderServerRank = 0;
|
||||
constexpr size_t kWorkerMgrThreadPoolSize = 32;
|
||||
constexpr size_t kWorkerMgrMaxTaskNum = 64;
|
||||
constexpr size_t kCipherMgrThreadPoolSize = 32;
|
||||
constexpr size_t kCipherMgrMaxTaskNum = 64;
|
||||
constexpr size_t kExecutorThreadPoolSize = 32;
|
||||
constexpr size_t kExecutorMaxTaskNum = 32;
|
||||
constexpr int kHttpSuccess = 200;
|
||||
constexpr auto kPBProtocol = "PB";
|
||||
constexpr auto kFBSProtocol = "FBS";
|
||||
constexpr auto kAggregationKernelType = "Aggregation";
|
||||
constexpr auto kOptimizerKernelType = "Optimizer";
|
||||
constexpr auto kCtxFuncGraph = "FuncGraph";
|
||||
constexpr auto kCtxIterNum = "iteration";
|
||||
constexpr auto kCtxDeviceMetas = "device_metas";
|
||||
constexpr auto kCtxTotalTimeoutDuration = "total_timeout_duration";
|
||||
constexpr auto kCtxUpdateModelClientList = "update_model_client_list";
|
||||
constexpr auto kCtxUpdateModelClientNum = "update_model_client_num";
|
||||
|
||||
// This macro the current timestamp in milliseconds.
|
||||
#define CURRENT_TIME_MILLI \
|
||||
std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now().time_since_epoch())
|
||||
|
||||
#define RETURN_IF_NULL(expr, ret) \
|
||||
if (expr == nullptr) { \
|
||||
MS_LOG(ERROR) << #expr << " is nullptr."; \
|
||||
return ret; \
|
||||
}
|
||||
|
||||
// This method returns the size in bytes of the given TypeId.
|
||||
inline size_t GetTypeIdByte(const TypeId &type) {
|
||||
switch (type) {
|
||||
case kNumberTypeFloat16:
|
||||
return 2;
|
||||
case kNumberTypeUInt32:
|
||||
case kNumberTypeFloat32:
|
||||
return 4;
|
||||
case kNumberTypeUInt64:
|
||||
return 8;
|
||||
default:
|
||||
MS_LOG(EXCEPTION) << "TypeId " << type << " not supported.";
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
inline AddressPtr GenerateParameterNodeAddrPtr(const CNodePtr &kernel_node, size_t param_idx) {
|
||||
RETURN_IF_NULL(kernel_node, nullptr);
|
||||
auto param_node =
|
||||
AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(kernel_node, param_idx), 0).first->cast<ParameterPtr>();
|
||||
RETURN_IF_NULL(param_node, nullptr);
|
||||
auto param_tensor = param_node->default_param()->cast<tensor::TensorPtr>();
|
||||
RETURN_IF_NULL(param_tensor, nullptr);
|
||||
AddressPtr addr = std::make_shared<kernel::Address>();
|
||||
addr->addr = param_tensor->data_c();
|
||||
addr->size = param_tensor->data().nbytes();
|
||||
return addr;
|
||||
}
|
||||
|
||||
// Definitions for Federated Learning.
|
||||
|
||||
// Definitions for Parameter Server.
|
||||
|
||||
} // namespace server
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PS_SERVER_COMMON_H_
|
|
@ -0,0 +1,58 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ps/server/consistent_hash_ring.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace server {
|
||||
bool ConsistentHashRing::Insert(uint32_t rank) {
|
||||
std::string physical_node_hash_key = std::to_string(rank);
|
||||
for (uint32_t i = 0; i < virtual_node_num_; i++) {
|
||||
physical_node_hash_key += "#" + std::to_string(i);
|
||||
MS_LOG(DEBUG) << "Insert virtual node " << physical_node_hash_key << " for node " << rank;
|
||||
|
||||
size_t hash_value = std::hash<std::string>()(physical_node_hash_key);
|
||||
if (ring_.count(hash_value) != 0) {
|
||||
MS_LOG(WARNING) << "Virtual node " << physical_node_hash_key << " is already mapped to the ring.";
|
||||
continue;
|
||||
}
|
||||
ring_[hash_value] = rank;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ConsistentHashRing::Erase(uint32_t rank) {
|
||||
for (auto iterator = ring_.begin(); iterator != ring_.end();) {
|
||||
if (iterator->second == rank) {
|
||||
ring_.erase(iterator++);
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
uint32_t ConsistentHashRing::Find(const std::string &key) {
|
||||
size_t hash_value = std::hash<std::string>()(key);
|
||||
auto iterator = ring_.lower_bound(hash_value);
|
||||
if (iterator == ring_.end()) {
|
||||
// If the virtual node is not found clockwise, the key will be mapped to the first virtual node on the ring.
|
||||
iterator = ring_.begin();
|
||||
}
|
||||
return iterator->second;
|
||||
}
|
||||
} // namespace server
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,64 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PS_SERVER_CONSISTENT_HASH_RING_H_
|
||||
#define MINDSPORE_CCSRC_PS_SERVER_CONSISTENT_HASH_RING_H_
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace server {
|
||||
// To support distributed storage and make servers easy to scale-out and scale-in for a large load of metadata in
|
||||
// server, we use class ConsistentHashRing to help servers find out which metadata is stored in which server node.
|
||||
|
||||
// Class ConsistentHashRing implements the algorithm described in the paper
|
||||
// <https://dl.acm.org/doi/pdf/10.1145/258533.258660>.
|
||||
|
||||
// This class will create a ring for hash values of metadata and server nodes. Each server could use this ring to
|
||||
// retrieve data stored in other servers according to the hash keys. The time complexity for adding/deleting/searching
|
||||
// of this algorithm is basically O(log n).
|
||||
class ConsistentHashRing {
|
||||
public:
|
||||
// The parameter virtual_node_num for constructor means the virtual node number to be created for each physical server
|
||||
// node. According to the paper, these virtual nodes could help spread data to all the servers and ensuring balancing
|
||||
// at the same time. And when we say "adding/deleting/searching", we are talking about operations on thease virtual
|
||||
// nodes instead of the physical nodes.
|
||||
explicit ConsistentHashRing(uint32_t virtual_node_num = 128) : virtual_node_num_(virtual_node_num) {}
|
||||
~ConsistentHashRing() = default;
|
||||
|
||||
// Insert several virtual nodes for a server into this ring according to its rank id.
|
||||
bool Insert(uint32_t rank);
|
||||
|
||||
// Remove virtual nodes for a server according to its rank id.
|
||||
bool Erase(uint32_t rank);
|
||||
|
||||
// Find the physical server node's rank according to the metadata's key.
|
||||
uint32_t Find(const std::string &key);
|
||||
|
||||
private:
|
||||
uint32_t virtual_node_num_;
|
||||
// The hash ring for the server nodes.
|
||||
// Key is the hash value of the virtual node.
|
||||
// Value is the physical node' rank id.
|
||||
std::map<size_t, uint32_t> ring_;
|
||||
};
|
||||
} // namespace server
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PS_SERVER_CONSISTENT_HASH_RING_H_
|
|
@ -0,0 +1,315 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ps/server/executor.h"
|
||||
#include <set>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace server {
|
||||
void Executor::Init(const FuncGraphPtr &func_graph, size_t aggregation_count) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
if (aggregation_count == 0) {
|
||||
MS_LOG(EXCEPTION) << "Server aggregation count must be greater than 0";
|
||||
return;
|
||||
}
|
||||
aggregation_count_ = aggregation_count;
|
||||
|
||||
// Initialize each trainable parameter's aggregator, including memory register, aggregation algorithms and optimizers.
|
||||
bool ret = InitParamAggregator(func_graph);
|
||||
if (!ret) {
|
||||
MS_LOG(EXCEPTION) << "Initializing parameter aggregators failed.";
|
||||
return;
|
||||
}
|
||||
initialized_ = true;
|
||||
return;
|
||||
}
|
||||
|
||||
bool Executor::initialized() const { return initialized_; }
|
||||
|
||||
bool Executor::HandlePush(const std::string ¶m_name, const UploadData &upload_data) {
|
||||
MS_LOG(DEBUG) << "Do Push for parameter " << param_name;
|
||||
if (param_aggrs_.count(param_name) == 0) {
|
||||
MS_LOG(WARNING) << "Parameter " << param_name << " is not registered in server.";
|
||||
return false;
|
||||
}
|
||||
|
||||
std::mutex &mtx = parameter_mutex_[param_name];
|
||||
std::unique_lock<std::mutex> lock(mtx);
|
||||
auto ¶m_aggr = param_aggrs_[param_name];
|
||||
|
||||
// Push operation needs to wait until the pulling process is done.
|
||||
while (!param_aggr->IsPullingDone()) {
|
||||
lock.unlock();
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(5));
|
||||
lock.lock();
|
||||
}
|
||||
|
||||
// 1.Update data with the uploaded data of the worker.
|
||||
if (!param_aggr->UpdateData(upload_data)) {
|
||||
MS_LOG(ERROR) << "Updating data for parameter " << param_name << " failed.";
|
||||
return false;
|
||||
}
|
||||
// 2.Launch aggregation for this trainable parameter.
|
||||
if (!param_aggr->LaunchAggregators()) {
|
||||
MS_LOG(ERROR) << "Launching aggregators for parameter " << param_name << " failed.";
|
||||
return false;
|
||||
}
|
||||
if (param_aggr->IsAggregationDone()) {
|
||||
// 3.After the aggregation is done, optimize the trainable parameter.
|
||||
if (!param_aggr->LaunchOptimizers()) {
|
||||
MS_LOG(ERROR) << "Optimizing for parameter " << param_name << " failed.";
|
||||
return false;
|
||||
}
|
||||
// 4.Reset pulling and aggregation status after optimizing is done.
|
||||
param_aggr->ResetPullingStatus();
|
||||
param_aggr->ResetAggregationStatus();
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Executor::HandleModelUpdate(const std::string ¶m_name, const UploadData &upload_data) {
|
||||
MS_LOG(DEBUG) << "Do UpdateModel for parameter " << param_name;
|
||||
if (param_aggrs_.count(param_name) == 0) {
|
||||
// The param_name could include some other parameters like momentum, but we don't think it's invalid. So here we
|
||||
// just print a warning log and return true.
|
||||
MS_LOG(WARNING) << "Parameter " << param_name << " is not registered in server.";
|
||||
return true;
|
||||
}
|
||||
|
||||
std::mutex &mtx = parameter_mutex_[param_name];
|
||||
std::unique_lock<std::mutex> lock(mtx);
|
||||
auto ¶m_aggr = param_aggrs_[param_name];
|
||||
|
||||
if (!param_aggr->UpdateData(upload_data)) {
|
||||
MS_LOG(ERROR) << "Updating data for parameter " << param_name << " failed.";
|
||||
return false;
|
||||
}
|
||||
// Different from Push, UpdateModel doesn't need to checkout the aggregation status.
|
||||
if (!param_aggr->LaunchAggregators()) {
|
||||
MS_LOG(ERROR) << "Launching aggregators for parameter " << param_name << " failed.";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Executor::HandleModelUpdateAsync(const std::map<std::string, UploadData> &feature_map) {
|
||||
std::unique_lock<std::mutex> model_lock(model_mutex_);
|
||||
for (const auto &trainable_param : feature_map) {
|
||||
const std::string ¶m_name = trainable_param.first;
|
||||
if (param_aggrs_.count(param_name) == 0) {
|
||||
MS_LOG(WARNING) << "Parameter " << param_name << " is not registered in server.";
|
||||
continue;
|
||||
}
|
||||
|
||||
std::mutex &mtx = parameter_mutex_[param_name];
|
||||
std::unique_lock<std::mutex> lock(mtx);
|
||||
auto ¶m_aggr = param_aggrs_[param_name];
|
||||
|
||||
const UploadData &upload_data = trainable_param.second;
|
||||
if (!param_aggr->UpdateData(upload_data)) {
|
||||
MS_LOG(ERROR) << "Updating data for parameter " << param_name << " failed.";
|
||||
return false;
|
||||
}
|
||||
if (!param_aggr->LaunchAggregators()) {
|
||||
MS_LOG(ERROR) << "Launching aggregators for parameter " << param_name << " failed.";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Executor::HandleOverwriteWeightsByKey(const std::map<std::string, Address> &feature_map) {
|
||||
for (const auto &trainable_param : feature_map) {
|
||||
const std::string ¶m_name = trainable_param.first;
|
||||
if (param_aggrs_.count(param_name) == 0) {
|
||||
MS_LOG(WARNING) << "Weight " << param_name << " is not registered in server.";
|
||||
continue;
|
||||
}
|
||||
|
||||
std::mutex &mtx = parameter_mutex_[param_name];
|
||||
std::unique_lock<std::mutex> lock(mtx);
|
||||
auto ¶m_aggr = param_aggrs_[param_name];
|
||||
|
||||
AddressPtr old_weight = param_aggr->GetWeight();
|
||||
if (old_weight == nullptr) {
|
||||
MS_LOG(ERROR) << "Get weight of " << param_name << " failed: the AddressPtr is nullptr.";
|
||||
return false;
|
||||
}
|
||||
|
||||
const Address &new_weight = trainable_param.second;
|
||||
if (new_weight.addr == nullptr) {
|
||||
MS_LOG(ERROR) << "The new weight is nullptr.";
|
||||
return false;
|
||||
}
|
||||
|
||||
int ret = memcpy_s(old_weight->addr, old_weight->size, new_weight.addr, new_weight.size);
|
||||
if (ret != 0) {
|
||||
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
AddressPtr Executor::HandlePull(const std::string ¶m_name) {
|
||||
MS_LOG(INFO) << "Handle blocking pull msg for parameter " << param_name;
|
||||
if (param_aggrs_.count(param_name) == 0) {
|
||||
MS_LOG(WARNING) << "Parameter " << param_name << " is not registered in server.";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::mutex &mtx = parameter_mutex_[param_name];
|
||||
std::unique_lock<std::mutex> lock(mtx);
|
||||
auto ¶m_aggr = param_aggrs_[param_name];
|
||||
|
||||
// Pulling must wait until the optimizing process is done.
|
||||
while (!param_aggr->IsOptimizingDone()) {
|
||||
lock.unlock();
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(5));
|
||||
lock.lock();
|
||||
}
|
||||
AddressPtr addr = param_aggr->Pull();
|
||||
// If this Pull is the last one, reset pulling and optimizing status.
|
||||
if (param_aggr->IsPullingDone()) {
|
||||
param_aggr->ResetOptimizingStatus();
|
||||
}
|
||||
return addr;
|
||||
}
|
||||
|
||||
std::map<std::string, AddressPtr> Executor::HandleAsyncGetModel() {
|
||||
std::unique_lock<std::mutex> lock(model_mutex_);
|
||||
return GetModel();
|
||||
}
|
||||
|
||||
std::map<std::string, AddressPtr> Executor::HandleGetWeightsByKey(const std::vector<std::string> ¶m_names) {
|
||||
std::map<std::string, AddressPtr> weights;
|
||||
for (const auto ¶m_name : param_names) {
|
||||
if (param_aggrs_.count(param_name) == 0) {
|
||||
MS_LOG(ERROR) << "Parameter " << param_name << " is not registered in server.";
|
||||
return weights;
|
||||
}
|
||||
|
||||
std::mutex &mtx = parameter_mutex_[param_name];
|
||||
std::unique_lock<std::mutex> lock(mtx);
|
||||
const auto ¶m_aggr = param_aggrs_[param_name];
|
||||
|
||||
AddressPtr addr = param_aggr->GetWeight();
|
||||
if (addr == nullptr) {
|
||||
MS_LOG(ERROR) << "Get weight of " << param_name << " failed: the AddressPtr is nullptr.";
|
||||
continue;
|
||||
}
|
||||
weights[param_name] = addr;
|
||||
}
|
||||
return weights;
|
||||
}
|
||||
|
||||
bool Executor::IsAllWeightAggregationDone() { return IsWeightAggrDone(param_names_); }
|
||||
|
||||
bool Executor::IsWeightAggrDone(const std::vector<std::string> ¶m_names) {
|
||||
for (const auto &name : param_names) {
|
||||
if (param_aggrs_.count(name) == 0) {
|
||||
MS_LOG(ERROR) << "Weight " << name << " is invalid in server.";
|
||||
return false;
|
||||
}
|
||||
|
||||
std::mutex &mtx = parameter_mutex_[name];
|
||||
std::unique_lock<std::mutex> lock(mtx);
|
||||
if (!param_aggrs_[name]->IsAggregationDone()) {
|
||||
MS_LOG(DEBUG) << "Update model for " << name << " is not done yet.";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void Executor::ResetAggregationStatus() {
|
||||
for (const auto ¶m_name : param_names_) {
|
||||
std::mutex &mtx = parameter_mutex_[param_name];
|
||||
std::unique_lock<std::mutex> lock(mtx);
|
||||
param_aggrs_[param_name]->ResetAggregationStatus();
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
std::map<std::string, AddressPtr> Executor::GetModel() {
|
||||
std::map<std::string, AddressPtr> model = {};
|
||||
for (const auto &name : param_names_) {
|
||||
std::mutex &mtx = parameter_mutex_[name];
|
||||
std::unique_lock<std::mutex> lock(mtx);
|
||||
AddressPtr addr = param_aggrs_[name]->GetWeight();
|
||||
if (addr == nullptr) {
|
||||
MS_LOG(WARNING) << "Get weight of " << name << " failed.";
|
||||
continue;
|
||||
}
|
||||
model[name] = addr;
|
||||
}
|
||||
return model;
|
||||
}
|
||||
|
||||
// bool Executor::Unmask() {
|
||||
// auto model = GetModel();
|
||||
// return mindarmour::CipherMgr::GetInstance().UnMask(model);
|
||||
// }
|
||||
|
||||
const std::vector<std::string> &Executor::param_names() const { return param_names_; }
|
||||
|
||||
std::string Executor::GetTrainableParamName(const CNodePtr &cnode) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
std::string cnode_name = AnfAlgo::GetCNodeName(cnode);
|
||||
if (kNameToIdxMap.count(cnode_name) == 0) {
|
||||
return "";
|
||||
}
|
||||
const OptimParamNameToIndex &index_info = kNameToIdxMap.at(cnode_name);
|
||||
size_t weight_idx = index_info.at("inputs").at(kWeight);
|
||||
AnfNodePtr weight_node = AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(cnode, weight_idx), 0).first;
|
||||
MS_EXCEPTION_IF_NULL(weight_node);
|
||||
if (!weight_node->isa<Parameter>()) {
|
||||
MS_LOG(EXCEPTION) << weight_idx << " input of " << cnode_name << " is not a Parameter.";
|
||||
}
|
||||
return weight_node->fullname_with_scope();
|
||||
}
|
||||
|
||||
bool Executor::InitParamAggregator(const FuncGraphPtr &func_graph) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
const auto &cnodes = func_graph->GetOrderedCnodes();
|
||||
for (const auto &cnode : cnodes) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
const std::string ¶m_name = GetTrainableParamName(cnode);
|
||||
if (param_name.empty()) {
|
||||
continue;
|
||||
}
|
||||
if (param_aggrs_.count(param_name) != 0) {
|
||||
MS_LOG(WARNING) << param_name << " already has its control flow.";
|
||||
continue;
|
||||
}
|
||||
|
||||
std::shared_ptr<ParameterAggregator> param_aggr = std::make_shared<ParameterAggregator>();
|
||||
MS_EXCEPTION_IF_NULL(param_aggr);
|
||||
param_names_.push_back(param_name);
|
||||
param_aggrs_[param_name] = param_aggr;
|
||||
parameter_mutex_[param_name];
|
||||
param_aggr->Init(cnode, aggregation_count_);
|
||||
MS_LOG(DEBUG) << "Initializing control flow for param_name " << param_name << " success.";
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace server
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,125 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PS_SERVER_EXECUTOR_H_
|
||||
#define MINDSPORE_CCSRC_PS_SERVER_EXECUTOR_H_
|
||||
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <mutex>
|
||||
#include <condition_variable>
|
||||
#include "ps/server/common.h"
|
||||
#include "ps/server/parameter_aggregator.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace server {
|
||||
// Executor is the entrance for server to handle aggregation, optimizing, model querying, etc. It handles
|
||||
// logics relevant to kernel launching.
|
||||
class Executor {
|
||||
public:
|
||||
static Executor &GetInstance() {
|
||||
static Executor instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
// FuncGraphPtr func_graph is the graph compiled by the frontend. aggregation_count is the number which will
|
||||
// be used for aggregators.
|
||||
// As noted in header file parameter_aggregator.h, we create aggregators by trainable parameters, which is the
|
||||
// optimizer cnode's input. So we need to initialize server executor using func_graph.
|
||||
void Init(const FuncGraphPtr &func_graph, size_t aggregation_count);
|
||||
|
||||
// Called in parameter server training mode to do Push operation.
|
||||
// For the same trainable parameter, HandlePush method must be called aggregation_count_ times before it's considered
|
||||
// as completed.
|
||||
bool HandlePush(const std::string ¶m_name, const UploadData &upload_data);
|
||||
|
||||
// Called in parameter server training mode to do Pull operation.
|
||||
// Returns the value of parameter param_name.
|
||||
// HandlePull method must be called the same times as HandlePush is called before it's considered as
|
||||
// completed.
|
||||
AddressPtr HandlePull(const std::string ¶m_name);
|
||||
|
||||
// Called in federated learning training mode. Update value for parameter param_name.
|
||||
bool HandleModelUpdate(const std::string ¶m_name, const UploadData &upload_data);
|
||||
|
||||
// Called in asynchronous federated learning training mode. Update current model with the new feature map
|
||||
// asynchronously.
|
||||
bool HandleModelUpdateAsync(const std::map<std::string, UploadData> &feature_map);
|
||||
|
||||
// Called in asynchronous federated learning training mode. Returns whole model in key-value where key refers to the
|
||||
// parameter name.
|
||||
std::map<std::string, AddressPtr> HandleAsyncGetModel();
|
||||
|
||||
// Forcibly overwrite specific weights in overwriteWeights message.
|
||||
bool HandleOverwriteWeightsByKey(const std::map<std::string, Address> &feature_map);
|
||||
|
||||
// Returns value for multiple trainable parameters passed by weight_names.
|
||||
std::map<std::string, AddressPtr> HandleGetWeightsByKey(const std::vector<std::string> ¶m_names);
|
||||
|
||||
// Reset the aggregation status for all aggregation kernels in the server.
|
||||
void ResetAggregationStatus();
|
||||
|
||||
// Judge whether aggregation processes for all weights/gradients are completed.
|
||||
bool IsAllWeightAggregationDone();
|
||||
|
||||
// Judge whether the aggregation processes for the given param_names are completed.
|
||||
bool IsWeightAggrDone(const std::vector<std::string> ¶m_names);
|
||||
|
||||
// Returns whole model in key-value where key refers to the parameter name.
|
||||
std::map<std::string, AddressPtr> GetModel();
|
||||
|
||||
// Returns whether the executor singleton is already initialized.
|
||||
bool initialized() const;
|
||||
|
||||
const std::vector<std::string> ¶m_names() const;
|
||||
|
||||
private:
|
||||
Executor() {}
|
||||
~Executor() = default;
|
||||
Executor(const Executor &) = delete;
|
||||
Executor &operator=(const Executor &) = delete;
|
||||
|
||||
// Returns the trainable parameter name parsed from this cnode.
|
||||
std::string GetTrainableParamName(const CNodePtr &cnode);
|
||||
|
||||
// Server's graph is basically the same as Worker's graph, so we can get all information from func_graph for later
|
||||
// computations. Including forward and backward propagation, aggregation, optimizing, etc.
|
||||
bool InitParamAggregator(const FuncGraphPtr &func_graph);
|
||||
|
||||
bool initialized_;
|
||||
size_t aggregation_count_;
|
||||
std::vector<std::string> param_names_;
|
||||
|
||||
// The map for trainable parameter names and its ParameterAggregator, as noted in the header file
|
||||
// parameter_aggregator.h
|
||||
std::map<std::string, std::shared_ptr<ParameterAggregator>> param_aggrs_;
|
||||
|
||||
// The mutex ensures that the operation on whole model is threadsafe.
|
||||
// The whole model is constructed by all trainable parameters.
|
||||
std::mutex model_mutex_;
|
||||
|
||||
// Because ParameterAggregator is not threadsafe, we have to create mutex for each ParameterAggregator so we can
|
||||
// acquire lock before calling its method.
|
||||
std::map<std::string, std::mutex> parameter_mutex_;
|
||||
};
|
||||
} // namespace server
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PS_SERVER_EXECUTOR_H_
|
|
@ -0,0 +1,56 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ps/server/iteration_timer.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace server {
|
||||
void IterationTimer::Start(const std::chrono::milliseconds &duration) {
|
||||
if (running_.load()) {
|
||||
MS_LOG(WARNING) << "The timer already started.";
|
||||
return;
|
||||
}
|
||||
running_ = true;
|
||||
end_time_ = CURRENT_TIME_MILLI + duration;
|
||||
monitor_thread_ = std::thread([&]() {
|
||||
while (running_.load()) {
|
||||
if (CURRENT_TIME_MILLI > end_time_) {
|
||||
timeout_callback_();
|
||||
running_ = false;
|
||||
}
|
||||
// The time tick is 1 millisecond.
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(1));
|
||||
}
|
||||
});
|
||||
monitor_thread_.detach();
|
||||
}
|
||||
|
||||
void IterationTimer::Stop() { running_ = false; }
|
||||
|
||||
void IterationTimer::SetTimeOutCallBack(const TimeOutCb &timeout_cb) {
|
||||
timeout_callback_ = timeout_cb;
|
||||
return;
|
||||
}
|
||||
|
||||
bool IterationTimer::IsTimeOut(const std::chrono::milliseconds ×tamp) {
|
||||
return timestamp > end_time_ ? true : false;
|
||||
}
|
||||
|
||||
bool IterationTimer::IsRunning() { return running_; }
|
||||
} // namespace server
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,64 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PS_SERVER_ITERATION_TIMER_H_
|
||||
#define MINDSPORE_CCSRC_PS_SERVER_ITERATION_TIMER_H_
|
||||
|
||||
#include <chrono>
|
||||
#include <atomic>
|
||||
#include <thread>
|
||||
#include <functional>
|
||||
#include "ps/server/common.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace server {
|
||||
// IterationTimer controls the time window for the purpose of eliminating trailing time of each iteration.
|
||||
class IterationTimer {
|
||||
public:
|
||||
IterationTimer() : running_(false), end_time_(0) {}
|
||||
~IterationTimer() = default;
|
||||
|
||||
// Start timing. The timer will stop after parameter 'duration' milliseconds.
|
||||
void Start(const std::chrono::milliseconds &duration);
|
||||
|
||||
// Caller could use this method to manually stop timing, otherwise the timer will keep timing until it expires.
|
||||
void Stop();
|
||||
|
||||
// Set the callback which will be called when the timer expires.
|
||||
void SetTimeOutCallBack(const TimeOutCb &timeout_cb);
|
||||
|
||||
// Judge whether current timestamp is out of time window's range since the Start function is called.
|
||||
bool IsTimeOut(const std::chrono::milliseconds ×tamp);
|
||||
|
||||
// Judge whether the timer is keeping timing.
|
||||
bool IsRunning();
|
||||
|
||||
private:
|
||||
// The running state for the timer.
|
||||
std::atomic<bool> running_;
|
||||
|
||||
// The timestamp in millesecond at which the timer should stop timing.
|
||||
std::chrono::milliseconds end_time_;
|
||||
|
||||
// The thread that keeps timing and call timeout_callback_ when the timer expires.
|
||||
std::thread monitor_thread_;
|
||||
TimeOutCb timeout_callback_;
|
||||
};
|
||||
} // namespace server
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PS_SERVER_ITERATION_TIMER_H_
|
|
@ -0,0 +1,95 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_AGGREGATION_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_AGGREGATION_KERNEL_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
|
||||
#include "ps/server/common.h"
|
||||
#include "ps/server/memory_register.h"
|
||||
#include "ps/server/kernel/params_info.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace server {
|
||||
namespace kernel {
|
||||
// AggregationKernel is the kernel for weight, grad or other kinds of parameters' aggregation.
|
||||
// For example, dense gradients accumulation, federated average, etc.
|
||||
// Normally the aggregation process in AggregationKernel is like a finite-state machine:
|
||||
// Initial->Aggregating->Aggregation done->Initial.
|
||||
class AggregationKernel : public CPUKernel {
|
||||
public:
|
||||
AggregationKernel() : name_(""), done_(false), done_count_(0), accum_count_(0) {}
|
||||
virtual ~AggregationKernel() = default;
|
||||
|
||||
// InitKernel and Launch methods are inherited from pure virtual function of CPUKernel so it must have implementation.
|
||||
virtual void InitKernel(const CNodePtr &kernel_node) {}
|
||||
virtual bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Server kernel's memory allocation method, which is different from the workflow in
|
||||
// Session(GPUSession/CPUSession/AscendSession).
|
||||
// virtual void AssignMemory(const CNodePtr &kernel_node, std::shared_ptr<MemoryRegister> memory_register) = 0;
|
||||
|
||||
// Set the cumulative count this aggregation kernel needs before aggregation is done.
|
||||
void set_done_count(size_t count) { done_count_ = count; }
|
||||
|
||||
// So we use Reset to set the finite-state machine state to Initial after considering this round of aggregation is
|
||||
// done.
|
||||
virtual void Reset() = 0;
|
||||
|
||||
virtual bool IsAggregationDone() = 0;
|
||||
|
||||
// Setter and getter of kernels parameters information.
|
||||
void set_params_info(const ParamsInfo ¶ms_info) { params_info_ = params_info; }
|
||||
const std::vector<std::string> &input_names() { return params_info_.inputs_names(); }
|
||||
const std::vector<std::string> &workspace_names() { return params_info_.workspace_names(); }
|
||||
const std::vector<std::string> &output_names() { return params_info_.outputs_names(); }
|
||||
|
||||
// Returns information about whether some inputs should reuse kernel node inputs memory.
|
||||
const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info() { return reuse_kernel_node_inputs_info_; }
|
||||
|
||||
protected:
|
||||
virtual void GenerateReuseKernelNodeInfo() = 0;
|
||||
// Aggregation kernel's name which is set by kernel register function.
|
||||
std::string name_;
|
||||
|
||||
// The aggregation is considered done after done_count_ times of accumulation.
|
||||
bool done_;
|
||||
|
||||
// Cumulative count this aggregation kernel needs before aggregation is done.
|
||||
size_t done_count_;
|
||||
|
||||
// Current cumulative count.
|
||||
size_t accum_count_;
|
||||
|
||||
// Parameters information used for kernel register, memory assignment, etc.
|
||||
ParamsInfo params_info_;
|
||||
|
||||
// Information about server kernel reusing kernel node inputs memory from the front end.
|
||||
// Key refers to the server kernel's input index. Value refers to the kernel node's input index.
|
||||
ReuseKernelNodeInfo reuse_kernel_node_inputs_info_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace server
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_AGGREGATION_KERNEL_H_
|
|
@ -0,0 +1,71 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ps/server/kernel/aggregation_kernel_factory.h"
|
||||
#include <utility>
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace server {
|
||||
namespace kernel {
|
||||
bool AggregationKernelFactory::Matched(const ParamsInfo ¶ms_info, const CNodePtr &kernel_node) {
|
||||
std::string cnode_name = AnfAlgo::GetCNodeName(kernel_node);
|
||||
if (kNameToIdxMap.count(cnode_name) == 0) {
|
||||
MS_LOG(ERROR) << "Can't find index info for kernel " << cnode_name;
|
||||
return false;
|
||||
}
|
||||
|
||||
auto input_name_to_idx = kNameToIdxMap.at(cnode_name).at("inputs");
|
||||
size_t input_num = params_info.inputs_num();
|
||||
for (size_t i = 0; i < input_num; i++) {
|
||||
auto one_input_name_type = params_info.inputs_name_type(i);
|
||||
std::string name = one_input_name_type.first;
|
||||
if (input_name_to_idx.count(name) == 0) {
|
||||
MS_LOG(DEBUG) << cnode_name << " does not have input named " << name
|
||||
<< ". This is the new input for this aggregation kernel.";
|
||||
continue;
|
||||
}
|
||||
size_t input_idx = input_name_to_idx.at(name);
|
||||
TypeId kernel_node_input_type = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_idx);
|
||||
TypeId registered_input_type = one_input_name_type.second;
|
||||
if (registered_input_type != kernel_node_input_type) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
auto output_name_to_idx = kNameToIdxMap.at(cnode_name).at("outputs");
|
||||
size_t output_num = params_info.outputs_num();
|
||||
for (size_t i = 0; i < output_num; i++) {
|
||||
auto one_output_name_type = params_info.outputs_name_type(i);
|
||||
std::string name = one_output_name_type.first;
|
||||
if (output_name_to_idx.count(name) == 0) {
|
||||
MS_LOG(DEBUG) << cnode_name << " does not have output named " << name
|
||||
<< ". This is the new output for this aggregation kernel.";
|
||||
continue;
|
||||
}
|
||||
size_t output_idx = output_name_to_idx.at(name);
|
||||
TypeId kernel_node_output_type = AnfAlgo::GetOutputInferDataType(kernel_node, output_idx);
|
||||
TypeId registered_output_type = one_output_name_type.second;
|
||||
if (registered_output_type != kernel_node_output_type) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace kernel
|
||||
} // namespace server
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,71 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_AGGREGATION_KERNEL_FACTORY_H_
|
||||
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_AGGREGATION_KERNEL_FACTORY_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include "ps/server/kernel/kernel_factory.h"
|
||||
#include "ps/server/kernel/aggregation_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace server {
|
||||
namespace kernel {
|
||||
using AggregationKernelCreator = std::function<std::shared_ptr<AggregationKernel>()>;
|
||||
class AggregationKernelFactory : public KernelFactory<std::shared_ptr<AggregationKernel>, AggregationKernelCreator> {
|
||||
public:
|
||||
static AggregationKernelFactory &GetInstance() {
|
||||
static AggregationKernelFactory instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
private:
|
||||
AggregationKernelFactory() = default;
|
||||
~AggregationKernelFactory() override = default;
|
||||
AggregationKernelFactory(const AggregationKernelFactory &) = delete;
|
||||
AggregationKernelFactory &operator=(const AggregationKernelFactory &) = delete;
|
||||
|
||||
// Judge whether the server aggregation kernel can be created according to registered ParamsInfo.
|
||||
bool Matched(const ParamsInfo ¶ms_info, const CNodePtr &kernel_node) override;
|
||||
};
|
||||
|
||||
class AggregationKernelRegister {
|
||||
public:
|
||||
AggregationKernelRegister(const std::string &name, const ParamsInfo ¶ms_info,
|
||||
AggregationKernelCreator &&creator) {
|
||||
AggregationKernelFactory::GetInstance().Register(name, params_info, std::move(creator));
|
||||
}
|
||||
};
|
||||
|
||||
// Register aggregation kernel with one template type T.
|
||||
#define REG_AGGREGATION_KERNEL(NAME, PARAMS_INFO, CLASS, T) \
|
||||
static_assert(std::is_base_of<AggregationKernel, CLASS<T>>::value, " must be base of AggregationKernel"); \
|
||||
static const AggregationKernelRegister g_##NAME##_##T##_aggregation_kernel_reg( \
|
||||
#NAME, PARAMS_INFO, []() { return std::make_shared<CLASS<T>>(); });
|
||||
|
||||
// Register aggregation kernel with two template types: T and S.
|
||||
#define REG_AGGREGATION_KERNEL_TWO(NAME, PARAMS_INFO, CLASS, T, S) \
|
||||
static_assert(std::is_base_of<AggregationKernel, CLASS<T, S>>::value, " must be base of AggregationKernel"); \
|
||||
static const AggregationKernelRegister g_##NAME##_##T##_##S##_aggregation_kernel_reg( \
|
||||
#NAME, PARAMS_INFO, []() { return std::make_shared<CLASS<T, S>>(); });
|
||||
} // namespace kernel
|
||||
} // namespace server
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_AGGREGATION_KERNEL_FACTORY_H_
|
|
@ -0,0 +1,34 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ps/server/kernel/apply_momentum_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace server {
|
||||
namespace kernel {
|
||||
REG_OPTIMIZER_KERNEL(ApplyMomentum,
|
||||
ParamsInfo()
|
||||
.AddInputNameType(kWeight, kNumberTypeFloat32)
|
||||
.AddInputNameType(kAccumulation, kNumberTypeFloat32)
|
||||
.AddInputNameType(kLearningRate, kNumberTypeFloat32)
|
||||
.AddInputNameType(kGradient, kNumberTypeFloat32)
|
||||
.AddInputNameType(kMomentum, kNumberTypeFloat32),
|
||||
ApplyMomentumKernel, float)
|
||||
} // namespace kernel
|
||||
} // namespace server
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,61 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_APPLY_MOMENTUM_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_APPLY_MOMENTUM_KERNEL_H_
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include "backend/kernel_compiler/cpu/apply_momentum_cpu_kernel.h"
|
||||
#include "ps/server/kernel/optimizer_kernel.h"
|
||||
#include "ps/server/kernel/optimizer_kernel_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace server {
|
||||
namespace kernel {
|
||||
using mindspore::kernel::ApplyMomentumCPUKernel;
|
||||
template <typename T>
|
||||
class ApplyMomentumKernel : public ApplyMomentumCPUKernel, public OptimizerKernel {
|
||||
public:
|
||||
ApplyMomentumKernel() = default;
|
||||
~ApplyMomentumKernel() override = default;
|
||||
|
||||
void InitKernel(const CNodePtr &cnode) override {
|
||||
ApplyMomentumCPUKernel::InitKernel(cnode);
|
||||
InitServerKernelInputOutputSize(cnode);
|
||||
GenerateReuseKernelNodeInfo();
|
||||
}
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override {
|
||||
return ApplyMomentumCPUKernel::Launch(inputs, workspace, outputs);
|
||||
}
|
||||
|
||||
void GenerateReuseKernelNodeInfo() override {
|
||||
reuse_kernel_node_inputs_info_.insert(std::make_pair(kWeight, 0));
|
||||
reuse_kernel_node_inputs_info_.insert(std::make_pair(kAccumulation, 1));
|
||||
reuse_kernel_node_inputs_info_.insert(std::make_pair(kLearningRate, 2));
|
||||
reuse_kernel_node_inputs_info_.insert(std::make_pair(kMomentum, 4));
|
||||
return;
|
||||
}
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace server
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_APPLY_MOMENTUM_KERNEL_H_
|
|
@ -0,0 +1,30 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ps/server/kernel/dense_grad_accum_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace server {
|
||||
namespace kernel {
|
||||
REG_AGGREGATION_KERNEL(
|
||||
DenseGradAccum,
|
||||
ParamsInfo().AddInputNameType(kGradient, kNumberTypeFloat32).AddInputNameType(kNewGradient, kNumberTypeFloat32),
|
||||
DenseGradAccumKernel, float)
|
||||
} // namespace kernel
|
||||
} // namespace server
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,95 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_DENSE_GRAD_ACCUM_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_DENSE_GRAD_ACCUM_KERNEL_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <functional>
|
||||
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
|
||||
#include "ps/server/kernel/aggregation_kernel.h"
|
||||
#include "ps/server/kernel/aggregation_kernel_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace server {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
class DenseGradAccumKernel : public AggregationKernel {
|
||||
public:
|
||||
DenseGradAccumKernel() = default;
|
||||
~DenseGradAccumKernel() override = default;
|
||||
|
||||
void InitKernel(const CNodePtr &kernel_node) override {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
std::string cnode_name = AnfAlgo::GetCNodeName(kernel_node);
|
||||
if (kNameToIdxMap.count(cnode_name) == 0 || kNameToIdxMap.at(cnode_name).count("inputs") == 0 ||
|
||||
kNameToIdxMap.at(cnode_name).at("inputs").count("grad") == 0) {
|
||||
MS_LOG(EXCEPTION) << "Can't find index info of grad for kernel " << cnode_name;
|
||||
return;
|
||||
}
|
||||
size_t cnode_grad_idx = kNameToIdxMap.at(cnode_name).at("inputs").at("grad");
|
||||
std::vector<size_t> grad_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, cnode_grad_idx);
|
||||
size_t grad_size = std::accumulate(grad_shape.begin(), grad_shape.end(), sizeof(T), std::multiplies<size_t>());
|
||||
input_size_list_.push_back(grad_size);
|
||||
size_t new_grad_size = grad_size;
|
||||
input_size_list_.push_back(new_grad_size);
|
||||
GenerateReuseKernelNodeInfo();
|
||||
return;
|
||||
}
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override {
|
||||
if (accum_count_ == 0) {
|
||||
int ret = memset_s(inputs[0]->addr, inputs[0]->size, 0x00, inputs[0]->size);
|
||||
if (ret != 0) {
|
||||
MS_LOG(ERROR) << "memset_s error, errorno(" << ret << ")";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
T *grad_addr = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
T *new_grad_addr = reinterpret_cast<T *>(inputs[1]->addr);
|
||||
for (size_t i = 0; i < inputs[0]->size / sizeof(T); i++) {
|
||||
grad_addr[i] += new_grad_addr[i];
|
||||
}
|
||||
|
||||
accum_count_++;
|
||||
if (accum_count_ > done_count_) {
|
||||
MS_LOG(ERROR) << "accum_count_ should not be greater than done_count_ " << done_count_;
|
||||
return false;
|
||||
}
|
||||
if (accum_count_ == done_count_) {
|
||||
for (size_t i = 0; i < inputs[0]->size / sizeof(T); i++) {
|
||||
grad_addr[i] /= done_count_;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void Reset() { accum_count_ = 0; }
|
||||
|
||||
bool IsAggregationDone() { return accum_count_ >= done_count_; }
|
||||
|
||||
void GenerateReuseKernelNodeInfo() override { return; }
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace server
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_DENSE_GRAD_ACCUM_KERNEL_H_
|
|
@ -0,0 +1,92 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_KERNEL_FACTORY_H_
|
||||
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_KERNEL_FACTORY_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include <unordered_map>
|
||||
#include "ps/server/common.h"
|
||||
#include "ps/server/kernel/params_info.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace server {
|
||||
namespace kernel {
|
||||
// KernelFactory is used to select and build kernels in server. It's the base class of OptimizerKernelFactory
|
||||
// and AggregationKernelFactory.
|
||||
|
||||
// Unlike normal MindSpore operator kernels, the server defines multiple types of kernels. For example: Aggregation
|
||||
// Kernel, Optimizer Kernel, Forward Kernel, etc. So we define KernelFactory as a template class for register of all
|
||||
// types of kernels.
|
||||
|
||||
// Because most information we need to create a server kernel is in func_graph passed by the front end, we create a
|
||||
// server kernel based on a cnode.
|
||||
|
||||
// Typename K refers to the shared_ptr of the kernel type.
|
||||
// Typename C refers to the creator function of the kernel.
|
||||
template <typename K, typename C>
|
||||
class KernelFactory {
|
||||
public:
|
||||
KernelFactory() = default;
|
||||
virtual ~KernelFactory() = default;
|
||||
|
||||
static KernelFactory &GetInstance() {
|
||||
static KernelFactory instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
// Kernels are registered by parameter information and its creator(constructor).
|
||||
void Register(const std::string &name, const ParamsInfo ¶ms_info, C &&creator) {
|
||||
name_to_creator_map_[name].push_back(std::make_pair(params_info, creator));
|
||||
}
|
||||
|
||||
// The kernels in server are created from func_graph's kernel_node passed by the front end.
|
||||
K Create(const std::string &name, const CNodePtr &kernel_node) {
|
||||
if (name_to_creator_map_.count(name) == 0) {
|
||||
MS_LOG(ERROR) << "Creating kernel failed: " << name << " is not registered.";
|
||||
}
|
||||
for (const auto &name_type_creator : name_to_creator_map_[name]) {
|
||||
const ParamsInfo ¶ms_info = name_type_creator.first;
|
||||
const C &creator = name_type_creator.second;
|
||||
if (Matched(params_info, kernel_node)) {
|
||||
auto kernel = creator();
|
||||
kernel->set_params_info(params_info);
|
||||
return kernel;
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
private:
|
||||
KernelFactory(const KernelFactory &) = delete;
|
||||
KernelFactory &operator=(const KernelFactory &) = delete;
|
||||
|
||||
// Judge whether the server kernel can be created according to registered ParamsInfo.
|
||||
virtual bool Matched(const ParamsInfo ¶ms_info, const CNodePtr &kernel_node) { return true; }
|
||||
|
||||
// Generally, a server kernel can correspond to several ParamsInfo which is registered by the method 'Register' in
|
||||
// server kernel's *.cc files.
|
||||
std::unordered_map<std::string, std::vector<std::pair<ParamsInfo, C>>> name_to_creator_map_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace server
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_KERNEL_FACTORY_H_
|
|
@ -0,0 +1,97 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_OPTIMIZER_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_OPTIMIZER_KERNEL_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <functional>
|
||||
#include "backend/kernel_compiler/common_utils.h"
|
||||
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
|
||||
#include "ps/server/common.h"
|
||||
#include "ps/server/memory_register.h"
|
||||
#include "ps/server/kernel/params_info.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace server {
|
||||
namespace kernel {
|
||||
using mindspore::kernel::IsSameShape;
|
||||
using mindspore::kernel::USE_NESTEROV;
|
||||
|
||||
// OptimizerKernel is the kernel in server for weights' optimizing.
|
||||
// Normally server's optimizer kernels should be inherited from CPU's optimzier kernels to reuse the implementation.
|
||||
class OptimizerKernel : public CPUKernel {
|
||||
public:
|
||||
OptimizerKernel() = default;
|
||||
virtual ~OptimizerKernel() = default;
|
||||
|
||||
// InitKernel and Launch methods are inherited from pure virtual function of CPUKernel so it must have implementation.
|
||||
virtual void InitKernel(const CNodePtr &kernel_node) {}
|
||||
virtual bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Server kernel's memory allocation method, which is different from the workflow in
|
||||
// Session(GPUSession/CPUSession/AscendSession).
|
||||
// virtual void AssignMemory(const CNodePtr &kernel_node, std::shared_ptr<MemoryRegister> memory_register) = 0;
|
||||
|
||||
// Setter and getter of kernels parameters information.
|
||||
void set_params_info(const ParamsInfo ¶ms_info) { params_info_ = params_info; }
|
||||
const std::vector<std::string> &input_names() { return params_info_.inputs_names(); }
|
||||
const std::vector<std::string> &workspace_names() { return params_info_.workspace_names(); }
|
||||
const std::vector<std::string> &output_names() { return params_info_.outputs_names(); }
|
||||
|
||||
// Returns information about whether some inputs should reuse kernel node inputs memory.
|
||||
const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info() { return reuse_kernel_node_inputs_info_; }
|
||||
|
||||
protected:
|
||||
virtual void GenerateReuseKernelNodeInfo() = 0;
|
||||
|
||||
void InitServerKernelInputOutputSize(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
size_t type_size = sizeof(float);
|
||||
for (size_t input_index = 0; input_index < input_num; ++input_index) {
|
||||
std::vector<size_t> shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, input_index);
|
||||
size_t tensor_size =
|
||||
shape.empty() ? type_size : std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies<size_t>());
|
||||
input_size_list_.emplace_back(tensor_size);
|
||||
}
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
for (size_t output_index = 0; output_index < output_num; ++output_index) {
|
||||
std::vector<size_t> shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, output_index);
|
||||
size_t tensor_size =
|
||||
shape.empty() ? type_size : std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies<size_t>());
|
||||
output_size_list_.emplace_back(tensor_size);
|
||||
}
|
||||
}
|
||||
|
||||
// Parameters information used for kernel register, memory assignment, etc.
|
||||
ParamsInfo params_info_;
|
||||
|
||||
// Information about server kernel reusing kernel node inputs memory from the front end.
|
||||
// Key refers to the server kernel's input index. Value refers to the kernel node's input index.
|
||||
ReuseKernelNodeInfo reuse_kernel_node_inputs_info_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace server
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_OPTIMIZER_KERNEL_H_
|
|
@ -0,0 +1,70 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ps/server/kernel/optimizer_kernel_factory.h"
|
||||
#include <utility>
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace server {
|
||||
namespace kernel {
|
||||
bool OptimizerKernelFactory::Matched(const ParamsInfo ¶ms_info, const CNodePtr &kernel_node) {
|
||||
std::string cnode_name = AnfAlgo::GetCNodeName(kernel_node);
|
||||
if (kNameToIdxMap.count(cnode_name) == 0) {
|
||||
MS_LOG(ERROR) << "Can't find index info for kernel " << cnode_name;
|
||||
return false;
|
||||
}
|
||||
|
||||
auto input_name_to_idx = kNameToIdxMap.at(cnode_name).at("inputs");
|
||||
size_t input_num = params_info.inputs_num();
|
||||
for (size_t i = 0; i < input_num; i++) {
|
||||
auto one_input_name_type = params_info.inputs_name_type(i);
|
||||
std::string name = one_input_name_type.first;
|
||||
if (input_name_to_idx.count(name) == 0) {
|
||||
MS_LOG(EXCEPTION) << cnode_name << " does not have input named " << name;
|
||||
return false;
|
||||
}
|
||||
size_t input_idx = input_name_to_idx.at(name);
|
||||
TypeId kernel_node_input_type = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_idx);
|
||||
TypeId registered_input_type = one_input_name_type.second;
|
||||
if (registered_input_type != kernel_node_input_type) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
auto output_name_to_idx = kNameToIdxMap.at(cnode_name).at("outputs");
|
||||
size_t output_num = params_info.outputs_num();
|
||||
for (size_t i = 0; i < output_num; i++) {
|
||||
auto one_output_name_type = params_info.outputs_name_type(i);
|
||||
std::string name = one_output_name_type.first;
|
||||
if (output_name_to_idx.count(name) == 0) {
|
||||
MS_LOG(EXCEPTION) << cnode_name << " does not have output named " << name;
|
||||
return false;
|
||||
}
|
||||
size_t output_idx = output_name_to_idx.at(name);
|
||||
TypeId kernel_node_output_type = AnfAlgo::GetOutputInferDataType(kernel_node, output_idx);
|
||||
TypeId registered_output_type = one_output_name_type.second;
|
||||
if (registered_output_type != kernel_node_output_type) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
} // namespace kernel
|
||||
} // namespace server
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,64 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_OPTIMIZER_KERNEL_FACTORY_H_
|
||||
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_OPTIMIZER_KERNEL_FACTORY_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include "ps/server/kernel/kernel_factory.h"
|
||||
#include "ps/server/kernel/optimizer_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace server {
|
||||
namespace kernel {
|
||||
using OptimizerKernelCreator = std::function<std::shared_ptr<OptimizerKernel>()>;
|
||||
class OptimizerKernelFactory : public KernelFactory<std::shared_ptr<OptimizerKernel>, OptimizerKernelCreator> {
|
||||
public:
|
||||
static OptimizerKernelFactory &GetInstance() {
|
||||
static OptimizerKernelFactory instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
private:
|
||||
OptimizerKernelFactory() = default;
|
||||
~OptimizerKernelFactory() override = default;
|
||||
OptimizerKernelFactory(const OptimizerKernelFactory &) = delete;
|
||||
OptimizerKernelFactory &operator=(const OptimizerKernelFactory &) = delete;
|
||||
|
||||
// Judge whether the server optimizer kernel can be created according to registered ParamsInfo.
|
||||
bool Matched(const ParamsInfo ¶ms_info, const CNodePtr &kernel_node) override;
|
||||
};
|
||||
|
||||
class OptimizerKernelRegister {
|
||||
public:
|
||||
OptimizerKernelRegister(const std::string &name, const ParamsInfo ¶ms_info, OptimizerKernelCreator &&creator) {
|
||||
OptimizerKernelFactory::GetInstance().Register(name, params_info, std::move(creator));
|
||||
}
|
||||
};
|
||||
|
||||
// Register optimizer kernel with one template type T.
|
||||
#define REG_OPTIMIZER_KERNEL(NAME, PARAMS_INFO, CLASS, T) \
|
||||
static_assert(std::is_base_of<OptimizerKernel, CLASS<T>>::value, " must be base of OptimizerKernel"); \
|
||||
static const OptimizerKernelRegister g_##NAME##_##T##_optimizer_kernel_reg( \
|
||||
#NAME, PARAMS_INFO, []() { return std::make_shared<CLASS<T>>(); });
|
||||
} // namespace kernel
|
||||
} // namespace server
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_OPTIMIZER_KERNEL_FACTORY_H_
|
|
@ -0,0 +1,68 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ps/server/kernel/params_info.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace server {
|
||||
namespace kernel {
|
||||
ParamsInfo &ParamsInfo::AddInputNameType(const std::string &name, TypeId type) {
|
||||
inputs_name_type_.push_back(std::make_pair(name, type));
|
||||
inputs_names_.push_back(name);
|
||||
return *this;
|
||||
}
|
||||
|
||||
ParamsInfo &ParamsInfo::AddWorkspaceNameType(const std::string &name, TypeId type) {
|
||||
workspaces_name_type_.push_back(std::make_pair(name, type));
|
||||
workspace_names_.push_back(name);
|
||||
return *this;
|
||||
}
|
||||
|
||||
ParamsInfo &ParamsInfo::AddOutputNameType(const std::string &name, TypeId type) {
|
||||
outputs_name_type_.push_back(std::make_pair(name, type));
|
||||
outputs_names_.push_back(name);
|
||||
return *this;
|
||||
}
|
||||
|
||||
size_t ParamsInfo::inputs_num() const { return inputs_name_type_.size(); }
|
||||
|
||||
size_t ParamsInfo::outputs_num() const { return outputs_name_type_.size(); }
|
||||
|
||||
const std::pair<std::string, TypeId> &ParamsInfo::inputs_name_type(size_t index) const {
|
||||
if (index >= inputs_name_type_.size()) {
|
||||
MS_LOG(EXCEPTION) << "Index " << index << " is out of bound of inputs_name_type_.";
|
||||
}
|
||||
return inputs_name_type_[index];
|
||||
}
|
||||
|
||||
const std::pair<std::string, TypeId> &ParamsInfo::outputs_name_type(size_t index) const {
|
||||
if (index >= outputs_name_type_.size()) {
|
||||
MS_LOG(EXCEPTION) << "Index " << index << " is out of bound of outputs_name_type_.";
|
||||
}
|
||||
return outputs_name_type_[index];
|
||||
}
|
||||
|
||||
const std::vector<std::string> &ParamsInfo::inputs_names() const { return inputs_names_; }
|
||||
|
||||
const std::vector<std::string> &ParamsInfo::workspace_names() const { return workspace_names_; }
|
||||
|
||||
const std::vector<std::string> &ParamsInfo::outputs_names() const { return outputs_names_; }
|
||||
} // namespace kernel
|
||||
} // namespace server
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,70 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_PARAMS_INFO_H_
|
||||
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_PARAMS_INFO_H_
|
||||
|
||||
#include <utility>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "ir/dtype/type_id.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace server {
|
||||
namespace kernel {
|
||||
// ParamsInfo is used for server computation kernel's register, e.g, ApplyMomentumKernel, FedAvgKernel, etc.
|
||||
// Register of a server kernel needs every inputs/workspace/outputs parameters' name and type.
|
||||
// For example:
|
||||
// ParamsInfo()
|
||||
// .AddInputNameType("input1_name", kNumberTypeFloat32)
|
||||
// .AddInputNameType("input2_name", kNumberTypeUInt64)
|
||||
// .AddWorkspaceNameType("workspace1_name", kNumberTypeFloat32)
|
||||
// .AddOutputNameType("output1_name", kNumberTypeUInt64)
|
||||
// This invocation describes a server kernel with parameters below:
|
||||
// An input with name "input1_name" and type float32.
|
||||
// An input with name "input1_name" and type uint_64.
|
||||
// A workspace with name "workspace1_name" and type float32.
|
||||
// An output with name "output1_name" and type float32.
|
||||
class ParamsInfo {
|
||||
public:
|
||||
ParamsInfo() = default;
|
||||
~ParamsInfo() = default;
|
||||
|
||||
ParamsInfo &AddInputNameType(const std::string &name, TypeId type);
|
||||
ParamsInfo &AddWorkspaceNameType(const std::string &name, TypeId type);
|
||||
ParamsInfo &AddOutputNameType(const std::string &name, TypeId type);
|
||||
size_t inputs_num() const;
|
||||
size_t outputs_num() const;
|
||||
const std::pair<std::string, TypeId> &inputs_name_type(size_t index) const;
|
||||
const std::pair<std::string, TypeId> &outputs_name_type(size_t index) const;
|
||||
const std::vector<std::string> &inputs_names() const;
|
||||
const std::vector<std::string> &workspace_names() const;
|
||||
const std::vector<std::string> &outputs_names() const;
|
||||
|
||||
private:
|
||||
std::vector<std::pair<std::string, TypeId>> inputs_name_type_;
|
||||
std::vector<std::pair<std::string, TypeId>> workspaces_name_type_;
|
||||
std::vector<std::pair<std::string, TypeId>> outputs_name_type_;
|
||||
std::vector<std::string> inputs_names_;
|
||||
std::vector<std::string> workspace_names_;
|
||||
std::vector<std::string> outputs_names_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace server
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_PARAMS_INFO_H_
|
|
@ -0,0 +1,46 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ps/server/local_meta_storage.h"
|
||||
#include <string>
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace server {
|
||||
void LocalMetaStorage::remove_value(const std::string &name) {
|
||||
std::unique_lock<std::mutex> lock(mtx_);
|
||||
if (key_to_meta_.count(name) != 0) {
|
||||
key_to_meta_.erase(key_to_meta_.find(name));
|
||||
}
|
||||
}
|
||||
|
||||
bool LocalMetaStorage::has_value(const std::string &name) {
|
||||
std::unique_lock<std::mutex> lock(mtx_);
|
||||
return key_to_meta_.count(name) != 0;
|
||||
}
|
||||
|
||||
void LocalMetaStorage::set_curr_iter_num(size_t num) {
|
||||
std::unique_lock<std::mutex> lock(mtx_);
|
||||
curr_iter_num_ = num;
|
||||
}
|
||||
|
||||
const size_t LocalMetaStorage::curr_iter_num() {
|
||||
std::unique_lock<std::mutex> lock(mtx_);
|
||||
return curr_iter_num_;
|
||||
}
|
||||
} // namespace server
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,88 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PS_SERVER_LOCAL_META_STORAGE_H_
|
||||
#define MINDSPORE_CCSRC_PS_SERVER_LOCAL_META_STORAGE_H_
|
||||
|
||||
#include <any>
|
||||
#include <mutex>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include "ps/server/common.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace server {
|
||||
// LocalMetaStorage class is used for metadata storage of this server process.
|
||||
// For example, the current iteration number, time windows for round kernels, etc.
|
||||
// LocalMetaStorage is threadsafe.
|
||||
class LocalMetaStorage {
|
||||
public:
|
||||
static LocalMetaStorage &GetInstance() {
|
||||
static LocalMetaStorage instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void put_value(const std::string &name, const T &value) {
|
||||
std::unique_lock<std::mutex> lock(mtx_);
|
||||
key_to_meta_[name] = value;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
const T &value(const std::string &name) {
|
||||
std::unique_lock<std::mutex> lock(mtx_);
|
||||
try {
|
||||
T value = std::any_cast<T>(key_to_meta_[name]);
|
||||
return value;
|
||||
} catch (const std::exception &e) {
|
||||
MS_LOG(EXCEPTION) << "Value of " << name << " is not set.";
|
||||
}
|
||||
}
|
||||
|
||||
// This method returns a reference so that user can change this value without calling put_value.
|
||||
template <typename T>
|
||||
T &mutable_value(const std::string &name) {
|
||||
std::unique_lock<std::mutex> lock(mtx_);
|
||||
try {
|
||||
return std::any_cast<T &>(key_to_meta_[name]);
|
||||
} catch (const std::exception &e) {
|
||||
MS_LOG(EXCEPTION) << "Value of " << name << " is not set.";
|
||||
}
|
||||
}
|
||||
|
||||
void remove_value(const std::string &name);
|
||||
bool has_value(const std::string &name);
|
||||
|
||||
void set_curr_iter_num(size_t num);
|
||||
const size_t curr_iter_num();
|
||||
|
||||
private:
|
||||
LocalMetaStorage() = default;
|
||||
~LocalMetaStorage() = default;
|
||||
LocalMetaStorage(const LocalMetaStorage &) = delete;
|
||||
LocalMetaStorage &operator=(const LocalMetaStorage &) = delete;
|
||||
|
||||
// key_to_meta_ stores metadata with key-value format.
|
||||
std::unordered_map<std::string, std::any> key_to_meta_;
|
||||
// This mutex makes sure that the operations on key_to_meta_ is threadsafe.
|
||||
std::mutex mtx_;
|
||||
size_t curr_iter_num_;
|
||||
};
|
||||
} // namespace server
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PS_SERVER_LOCAL_META_STORAGE_H_
|
|
@ -0,0 +1,34 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ps/server/memory_register.h"
|
||||
#include <utility>
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace server {
|
||||
void MemoryRegister::RegisterAddressPtr(const std::string &name, const AddressPtr &address) {
|
||||
addresses_.try_emplace(name, address);
|
||||
}
|
||||
|
||||
void MemoryRegister::StoreFloatArray(std::unique_ptr<float[]> *array) { float_arrays_.push_back(std::move(*array)); }
|
||||
|
||||
void MemoryRegister::StoreInt32Array(std::unique_ptr<int[]> *array) { int32_arrays_.push_back(std::move(*array)); }
|
||||
|
||||
void MemoryRegister::StoreUint64Array(std::unique_ptr<size_t[]> *array) { uint64_arrays_.push_back(std::move(*array)); }
|
||||
} // namespace server
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,88 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PS_SERVER_MEMORY_REGISTER_H_
|
||||
#define MINDSPORE_CCSRC_PS_SERVER_MEMORY_REGISTER_H_
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include <typeinfo>
|
||||
#include "ps/server/common.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace server {
|
||||
// Memory allocated in server is normally trainable parameters, hyperparameters, gradients, etc.
|
||||
// MemoryRegister registers the Memory with key-value format where key refers to address's name("grad", "weights",
|
||||
// etc) and value is AddressPtr.
|
||||
class MemoryRegister {
|
||||
public:
|
||||
MemoryRegister() = default;
|
||||
~MemoryRegister() = default;
|
||||
|
||||
std::map<std::string, AddressPtr> &addresses() { return addresses_; }
|
||||
void RegisterAddressPtr(const std::string &name, const AddressPtr &address);
|
||||
|
||||
// In some cases, memory is passed by unique_ptr which is allocated by caller. They needs to be stored as well to
|
||||
// avoid its data being released.
|
||||
template <typename T>
|
||||
void RegisterArray(const std::string &name, std::unique_ptr<T[]> *array, size_t size) {
|
||||
MS_EXCEPTION_IF_NULL(array);
|
||||
void *data = array->get();
|
||||
AddressPtr addr = std::make_shared<Address>();
|
||||
addr->addr = data;
|
||||
addr->size = size;
|
||||
|
||||
if (typeid(T) == typeid(int)) {
|
||||
auto int_arr = CastUniquePtr<int, T>(array);
|
||||
StoreInt32Array(&int_arr);
|
||||
} else if (typeid(T) == typeid(float)) {
|
||||
auto float_arr = CastUniquePtr<float, T>(array);
|
||||
StoreFloatArray(&float_arr);
|
||||
} else if (typeid(T) == typeid(size_t)) {
|
||||
auto uint64_arr = CastUniquePtr<size_t, T>(array);
|
||||
StoreUint64Array(&uint64_arr);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "MemoryRegister does not support type " << typeid(T).name();
|
||||
return;
|
||||
}
|
||||
|
||||
RegisterAddressPtr(name, addr);
|
||||
return;
|
||||
}
|
||||
|
||||
private:
|
||||
std::map<std::string, AddressPtr> addresses_;
|
||||
std::vector<std::unique_ptr<float[]>> float_arrays_;
|
||||
std::vector<std::unique_ptr<int[]>> int32_arrays_;
|
||||
std::vector<std::unique_ptr<size_t[]>> uint64_arrays_;
|
||||
|
||||
void StoreInt32Array(std::unique_ptr<int[]> *array);
|
||||
void StoreFloatArray(std::unique_ptr<float[]> *array);
|
||||
void StoreUint64Array(std::unique_ptr<size_t[]> *array);
|
||||
|
||||
template <typename T, typename S>
|
||||
std::unique_ptr<T[]> CastUniquePtr(std::unique_ptr<S[]> *array) {
|
||||
return std::unique_ptr<T[]>{reinterpret_cast<T *>(array->release())};
|
||||
}
|
||||
};
|
||||
} // namespace server
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PS_SERVER_MEMORY_REGISTER_H_
|
|
@ -0,0 +1,321 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ps/server/parameter_aggregator.h"
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include <algorithm>
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace server {
|
||||
bool ParameterAggregator::Init(const CNodePtr &cnode, size_t required_count) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
memory_register_ = std::make_shared<MemoryRegister>();
|
||||
MS_EXCEPTION_IF_NULL(memory_register_);
|
||||
|
||||
required_push_count_ = required_count;
|
||||
// The required_pull_count_ is the count for Pull, which should be the same as required_push_count_.
|
||||
// required_pull_count_ normally used in parameter server training mode.
|
||||
required_pull_count_ = required_count;
|
||||
|
||||
MS_LOG(DEBUG) << "Start initializing kernels for " << AnfAlgo::GetCNodeName(cnode);
|
||||
InitAggregationKernels(cnode);
|
||||
InitOptimizerKernels(cnode);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ParameterAggregator::UpdateData(const std::map<std::string, Address> &new_data) {
|
||||
std::map<std::string, AddressPtr> &name_to_addr = memory_register_->addresses();
|
||||
for (const auto &data : new_data) {
|
||||
const std::string &name = data.first;
|
||||
if (name_to_addr.count(name) == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
MS_LOG(DEBUG) << "Update data for " << name << ". Destination size: " << name_to_addr[name]->size
|
||||
<< ". Source size: " << data.second.size;
|
||||
int ret = memcpy_s(name_to_addr[name]->addr, name_to_addr[name]->size, data.second.addr, data.second.size);
|
||||
if (ret != 0) {
|
||||
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ParameterAggregator::LaunchAggregators() {
|
||||
for (auto &aggregator_with_params : aggregation_kernel_parameters_) {
|
||||
KernelParams ¶ms = aggregator_with_params.second;
|
||||
std::shared_ptr<kernel::AggregationKernel> aggr_kernel = aggregator_with_params.first;
|
||||
RETURN_IF_NULL(aggr_kernel, false);
|
||||
|
||||
bool ret = aggr_kernel->Launch(params.inputs, params.workspace, params.outputs);
|
||||
if (!ret) {
|
||||
MS_LOG(ERROR) << "Launching aggregation kernel " << typeid(aggr_kernel.get()).name() << " failed.";
|
||||
continue;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ParameterAggregator::LaunchOptimizers() {
|
||||
for (auto &optimizer_with_params : optimizer_kernel_parameters_) {
|
||||
KernelParams ¶ms = optimizer_with_params.second;
|
||||
std::shared_ptr<kernel::OptimizerKernel> optimizer_kernel = optimizer_with_params.first;
|
||||
RETURN_IF_NULL(optimizer_kernel, false);
|
||||
|
||||
bool ret = optimizer_kernel->Launch(params.inputs, params.workspace, params.outputs);
|
||||
if (!ret) {
|
||||
MS_LOG(ERROR) << "Launching optimizer kernel " << typeid(optimizer_kernel.get()).name() << " failed.";
|
||||
continue;
|
||||
}
|
||||
}
|
||||
// As long as all the optimizer kernels are launched, consider optimizing for this ParameterAggregator as done.
|
||||
optimizing_done_ = true;
|
||||
return true;
|
||||
}
|
||||
|
||||
AddressPtr ParameterAggregator::Pull() {
|
||||
if (memory_register_ == nullptr) {
|
||||
MS_LOG(ERROR)
|
||||
<< "The memory register of ParameterAggregator is nullptr. Please initialize ParameterAggregator first.";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
current_pull_count_++;
|
||||
if (current_pull_count_ == required_pull_count_) {
|
||||
pulling_done_ = true;
|
||||
}
|
||||
MS_LOG(DEBUG) << "The " << current_pull_count_ << " time of Pull. Pulling done status: " << pulling_done_;
|
||||
|
||||
std::map<std::string, AddressPtr> &name_to_addr = memory_register_->addresses();
|
||||
return name_to_addr["weight"];
|
||||
}
|
||||
|
||||
AddressPtr ParameterAggregator::GetWeight() {
|
||||
if (memory_register_ == nullptr) {
|
||||
MS_LOG(ERROR)
|
||||
<< "The memory register of ParameterAggregator is nullptr. Please initialize ParameterAggregator first.";
|
||||
return nullptr;
|
||||
}
|
||||
std::map<std::string, AddressPtr> &name_to_addr = memory_register_->addresses();
|
||||
return name_to_addr["weight"];
|
||||
}
|
||||
|
||||
void ParameterAggregator::ResetAggregationStatus() {
|
||||
for (auto &aggregator_with_params : aggregation_kernel_parameters_) {
|
||||
std::shared_ptr<kernel::AggregationKernel> aggr_kernel = aggregator_with_params.first;
|
||||
if (aggr_kernel == nullptr) {
|
||||
MS_LOG(ERROR) << "The aggregation kernel is nullptr.";
|
||||
continue;
|
||||
}
|
||||
aggr_kernel->Reset();
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
void ParameterAggregator::ResetOptimizingStatus() { optimizing_done_ = false; }
|
||||
|
||||
void ParameterAggregator::ResetPullingStatus() {
|
||||
pulling_done_ = false;
|
||||
current_pull_count_ = 0;
|
||||
}
|
||||
|
||||
bool ParameterAggregator::IsAggregationDone() const {
|
||||
// Only consider aggregation done after each aggregation kernel is done.
|
||||
for (auto &aggregator_with_params : aggregation_kernel_parameters_) {
|
||||
std::shared_ptr<kernel::AggregationKernel> aggr_kernel = aggregator_with_params.first;
|
||||
RETURN_IF_NULL(aggr_kernel, false);
|
||||
if (!aggr_kernel->IsAggregationDone()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ParameterAggregator::IsOptimizingDone() const { return optimizing_done_; }
|
||||
|
||||
bool ParameterAggregator::IsPullingDone() const { return pulling_done_; }
|
||||
|
||||
bool ParameterAggregator::InitAggregationKernels(const CNodePtr &cnode) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
std::vector<std::string> aggr_kernel_names = SelectAggregationAlgorithm(cnode);
|
||||
for (const std::string &name : aggr_kernel_names) {
|
||||
auto aggr_kernel = kernel::AggregationKernelFactory::GetInstance().Create(name, cnode);
|
||||
if (aggr_kernel == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Fail to create aggregation kernel " << name << " for " << AnfAlgo::GetCNodeName(cnode);
|
||||
return false;
|
||||
}
|
||||
|
||||
// set_done_count must be called before InitKernel because InitKernel may use this count.
|
||||
aggr_kernel->set_done_count(required_push_count_);
|
||||
aggr_kernel->InitKernel(cnode);
|
||||
|
||||
const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info = aggr_kernel->reuse_kernel_node_inputs_info();
|
||||
if (!AssignMemory(aggr_kernel, cnode, reuse_kernel_node_inputs_info, memory_register_)) {
|
||||
MS_LOG(EXCEPTION) << "Assigning memory for kernel " << name << " failed.";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!GenerateAggregationKernelParams(aggr_kernel, memory_register_)) {
|
||||
MS_LOG(EXCEPTION) << "Generating aggregation kernel parameters for " << name << " failed.";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ParameterAggregator::InitOptimizerKernels(const CNodePtr &cnode) {
|
||||
// if (PSContext::instance()->server_mode() == kServerModeFL) {
|
||||
// MS_LOG(DEBUG) << "Federated learning mode doesn't need optimizer kernel.";
|
||||
// return false;
|
||||
// }
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
const std::string &name = AnfAlgo::GetCNodeName(cnode);
|
||||
auto optimizer_kernel = kernel::OptimizerKernelFactory::GetInstance().Create(name, cnode);
|
||||
if (optimizer_kernel == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Failed to create optimizer kernel for " << name;
|
||||
return false;
|
||||
}
|
||||
|
||||
optimizer_kernel->InitKernel(cnode);
|
||||
|
||||
const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info = optimizer_kernel->reuse_kernel_node_inputs_info();
|
||||
if (!AssignMemory(optimizer_kernel, cnode, reuse_kernel_node_inputs_info, memory_register_)) {
|
||||
MS_LOG(EXCEPTION) << "Assigning memory for kernel " << name << " failed.";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!GenerateOptimizerKernelParams(optimizer_kernel, memory_register_)) {
|
||||
MS_LOG(ERROR) << "Generating optimizer kernel parameters failed.";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename K>
|
||||
bool ParameterAggregator::AssignMemory(K server_kernel, const CNodePtr &cnode,
|
||||
const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info,
|
||||
std::shared_ptr<MemoryRegister> memory_register) {
|
||||
MS_EXCEPTION_IF_NULL(server_kernel);
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
|
||||
const std::vector<std::string> &input_names = server_kernel->input_names();
|
||||
const std::vector<size_t> &input_size_list = server_kernel->GetInputSizeList();
|
||||
if (input_names.size() != input_size_list.size()) {
|
||||
MS_LOG(EXCEPTION) << "Server kernel " << typeid(server_kernel.get()).name()
|
||||
<< " input number is not matched: input_names size is " << input_names.size()
|
||||
<< ", input_size_list size is " << input_size_list.size();
|
||||
return false;
|
||||
}
|
||||
|
||||
if (reuse_kernel_node_inputs_info.size() > input_names.size()) {
|
||||
MS_LOG(EXCEPTION) << "The reuse kernel node information number is invalid: got "
|
||||
<< reuse_kernel_node_inputs_info.size() << ", but input_names size is " << input_names.size();
|
||||
return false;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < input_names.size(); i++) {
|
||||
const std::string &name = input_names[i];
|
||||
if (memory_register->addresses().count(name) != 0) {
|
||||
MS_LOG(DEBUG) << "The memory for " << name << " is already assigned.";
|
||||
continue;
|
||||
}
|
||||
if (reuse_kernel_node_inputs_info.count(name) != 0) {
|
||||
// Reusing memory of the kernel node means the memory of the input is already assigned by the front end, which
|
||||
// is to say, the input node is a parameter node.
|
||||
size_t index = reuse_kernel_node_inputs_info.at(name);
|
||||
MS_LOG(INFO) << "Try to reuse memory of kernel node " << AnfAlgo::GetCNodeName(cnode) << " for parameter " << name
|
||||
<< ", kernel node index " << index;
|
||||
AddressPtr input_addr = GenerateParameterNodeAddrPtr(cnode, index);
|
||||
MS_EXCEPTION_IF_NULL(input_addr);
|
||||
memory_register->RegisterAddressPtr(name, input_addr);
|
||||
} else {
|
||||
MS_LOG(INFO) << "Assign new memory for " << name;
|
||||
auto input_addr = std::make_unique<char[]>(input_size_list[i]);
|
||||
MS_EXCEPTION_IF_NULL(input_addr);
|
||||
memory_register->RegisterArray(name, &input_addr, input_size_list[i]);
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ParameterAggregator::GenerateAggregationKernelParams(const std::shared_ptr<kernel::AggregationKernel> aggr_kernel,
|
||||
const std::shared_ptr<MemoryRegister> memory_register) {
|
||||
RETURN_IF_NULL(aggr_kernel, false);
|
||||
RETURN_IF_NULL(memory_register, false);
|
||||
KernelParams aggr_params = {};
|
||||
|
||||
const std::vector<std::string> &input_names = aggr_kernel->input_names();
|
||||
std::transform(input_names.begin(), input_names.end(), std::back_inserter(aggr_params.inputs),
|
||||
[&](const std::string &name) { return memory_register->addresses()[name]; });
|
||||
|
||||
const std::vector<std::string> &workspace_names = aggr_kernel->workspace_names();
|
||||
std::transform(workspace_names.begin(), workspace_names.end(), std::back_inserter(aggr_params.workspace),
|
||||
[&](const std::string &name) { return memory_register->addresses()[name]; });
|
||||
|
||||
const std::vector<std::string> &output_names = aggr_kernel->output_names();
|
||||
std::transform(output_names.begin(), output_names.end(), std::back_inserter(aggr_params.outputs),
|
||||
[&](const std::string &name) { return memory_register->addresses()[name]; });
|
||||
|
||||
aggregation_kernel_parameters_.push_back(std::make_pair(aggr_kernel, aggr_params));
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ParameterAggregator::GenerateOptimizerKernelParams(const std::shared_ptr<kernel::OptimizerKernel> optimizer_kernel,
|
||||
const std::shared_ptr<MemoryRegister> memory_register) {
|
||||
RETURN_IF_NULL(optimizer_kernel, false);
|
||||
RETURN_IF_NULL(memory_register, false);
|
||||
KernelParams optimizer_params = {};
|
||||
|
||||
const std::vector<std::string> &input_names = optimizer_kernel->input_names();
|
||||
std::transform(input_names.begin(), input_names.end(), std::back_inserter(optimizer_params.inputs),
|
||||
[&](const std::string &name) { return memory_register->addresses()[name]; });
|
||||
|
||||
const std::vector<std::string> &workspace_names = optimizer_kernel->workspace_names();
|
||||
std::transform(workspace_names.begin(), workspace_names.end(), std::back_inserter(optimizer_params.workspace),
|
||||
[&](const std::string &name) { return memory_register->addresses()[name]; });
|
||||
|
||||
const std::vector<std::string> &output_names = optimizer_kernel->output_names();
|
||||
std::transform(output_names.begin(), output_names.end(), std::back_inserter(optimizer_params.outputs),
|
||||
[&](const std::string &name) { return memory_register->addresses()[name]; });
|
||||
|
||||
optimizer_kernel_parameters_.push_back(std::make_pair(optimizer_kernel, optimizer_params));
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<std::string> ParameterAggregator::SelectAggregationAlgorithm(const CNodePtr &cnode) {
|
||||
std::vector<std::string> aggregation_algorithm = {};
|
||||
MS_LOG(INFO) << "Aggregation algorithm selection result: " << aggregation_algorithm;
|
||||
return aggregation_algorithm;
|
||||
}
|
||||
|
||||
template bool ParameterAggregator::AssignMemory(std::shared_ptr<kernel::OptimizerKernel> server_kernel,
|
||||
const CNodePtr &cnode,
|
||||
const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info,
|
||||
std::shared_ptr<MemoryRegister> memory_register);
|
||||
|
||||
template bool ParameterAggregator::AssignMemory(std::shared_ptr<kernel::AggregationKernel> server_kernel,
|
||||
const CNodePtr &cnode,
|
||||
const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info,
|
||||
std::shared_ptr<MemoryRegister> memory_register);
|
||||
} // namespace server
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,139 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PS_SERVER_PARAMETER_AGGREGATOR_H_
|
||||
#define MINDSPORE_CCSRC_PS_SERVER_PARAMETER_AGGREGATOR_H_
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include "ps/server/common.h"
|
||||
#include "ps/server/memory_register.h"
|
||||
#include "ps/server/kernel/aggregation_kernel_factory.h"
|
||||
#include "ps/server/kernel/optimizer_kernel_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace server {
|
||||
// Encapsulate the parameters for a kernel into a struct to make it convenient for ParameterAggregator to launch server
|
||||
// kernels.
|
||||
typedef struct {
|
||||
std::vector<AddressPtr> inputs;
|
||||
std::vector<AddressPtr> workspace;
|
||||
std::vector<AddressPtr> outputs;
|
||||
} KernelParams;
|
||||
|
||||
// ParameterAggregator includes methods for aggregating gradients and optimizing weights(launching aggregation and
|
||||
// optimizer kernels), getting weights, etc. It's not thread-safe, which means the caller must acquire lock before
|
||||
// calling ParameterAggregator methods concurrently.
|
||||
|
||||
// Each ParameterAggregator is corresponding to one weight for now.
|
||||
|
||||
// ParameterAggregator is stateful because the process of aggregation and optimizing could be stateful.
|
||||
// For example, the finite-state machine for the ParameterAggregator in parameter server training mode is below:
|
||||
// Initial->Aggregating->Aggregation done->Optimizing->Optimizing done->Pulling->Pull done->Initial.
|
||||
class ParameterAggregator {
|
||||
public:
|
||||
ParameterAggregator()
|
||||
: server_mode_(ServerMode::PARAMETER_SERVER),
|
||||
required_push_count_(0),
|
||||
required_pull_count_(0),
|
||||
current_pull_count_(0),
|
||||
aggregation_done_(false),
|
||||
optimizing_done_(false),
|
||||
pulling_done_(true),
|
||||
memory_register_(nullptr) {}
|
||||
~ParameterAggregator() = default;
|
||||
|
||||
// Initialize ParameterAggregator with a cnode. This cnode is normally a optimizer kernel for now.
|
||||
// The parameter required_count helps ParameterAggregator to judge the current status if it's stateful.
|
||||
bool Init(const CNodePtr &cnode, size_t required_count = 0);
|
||||
|
||||
// Update old data stored in ParameterAggregator with new data.
|
||||
// The data could have many meanings: weights, gradients, learning_rate, momentum, etc.
|
||||
bool UpdateData(const std::map<std::string, Address> &new_data);
|
||||
|
||||
// Launch aggregators/optimizers of this ParameterAggregator in order.
|
||||
bool LaunchAggregators();
|
||||
bool LaunchOptimizers();
|
||||
|
||||
// The implementation for primitive Pull in parameter server training mode.
|
||||
// Every call of this method will increase the count for pull by 1.
|
||||
AddressPtr Pull();
|
||||
|
||||
// Different from the method Pull, this method simply returns the weight of this ParameterAggregator without causing
|
||||
// any change of status.
|
||||
AddressPtr GetWeight();
|
||||
|
||||
// After aggregation/optimizing/pulling of one iteration is done, caller must reset the status to ensure the
|
||||
// correctness of the aggregation/optimizing/pulling for next iteration.
|
||||
void ResetAggregationStatus();
|
||||
void ResetOptimizingStatus();
|
||||
void ResetPullingStatus();
|
||||
|
||||
// Returns the aggregation/optimizing/pulling status to the caller.
|
||||
bool IsAggregationDone() const;
|
||||
bool IsOptimizingDone() const;
|
||||
bool IsPullingDone() const;
|
||||
|
||||
private:
|
||||
// Initializing aggregation/optimizer kenerls based on the cnode. The reason of this is described in the file
|
||||
// kernel/kernel_factory.h.
|
||||
bool InitAggregationKernels(const CNodePtr &cnode);
|
||||
bool InitOptimizerKernels(const CNodePtr &cnode);
|
||||
|
||||
// Assign memory for server kernel K(AggregationKernel/OptimizerKernel).
|
||||
// The memory assigned can be accessed by MemoryRegister. The memory could be weights, gradients, learning_rate,
|
||||
// momentum, etc.
|
||||
template <typename K>
|
||||
bool AssignMemory(K server_kernel, const CNodePtr &cnode, const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info,
|
||||
std::shared_ptr<MemoryRegister> memory_register);
|
||||
|
||||
// Generate kernel parameters for aggregation/optimizer kernels. All the parameters is registered and stored in
|
||||
// memory_register.
|
||||
bool GenerateAggregationKernelParams(const std::shared_ptr<kernel::AggregationKernel> aggr_kernel,
|
||||
const std::shared_ptr<MemoryRegister> memory_register);
|
||||
bool GenerateOptimizerKernelParams(const std::shared_ptr<kernel::OptimizerKernel> optim_kernel,
|
||||
const std::shared_ptr<MemoryRegister> memory_register);
|
||||
|
||||
// The selection of the aggregation algorithm depends on multiple factors. For example, server mode, user
|
||||
// configuration, etc.
|
||||
std::vector<std::string> SelectAggregationAlgorithm(const CNodePtr &cnode);
|
||||
|
||||
ServerMode server_mode_;
|
||||
size_t required_push_count_;
|
||||
size_t required_pull_count_;
|
||||
size_t current_pull_count_;
|
||||
|
||||
// The status of aggregation/optimizing/pulling.
|
||||
bool aggregation_done_;
|
||||
bool optimizing_done_;
|
||||
bool pulling_done_;
|
||||
|
||||
// ParameterAggregator stores all data that it needs for aggregation, optimizing, etc.
|
||||
std::shared_ptr<MemoryRegister> memory_register_;
|
||||
|
||||
// Update could have multiple aggregation and optimizer server kernels.
|
||||
// Here stores multiple pairs of server kernels to parameters of their Launch function.
|
||||
std::vector<std::pair<std::shared_ptr<kernel::AggregationKernel>, KernelParams>> aggregation_kernel_parameters_;
|
||||
std::vector<std::pair<std::shared_ptr<kernel::OptimizerKernel>, KernelParams>> optimizer_kernel_parameters_;
|
||||
};
|
||||
} // namespace server
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PS_SERVER_PARAMETER_AGGREGATOR_H_
|
|
@ -168,6 +168,7 @@ list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/parameter_serve
|
|||
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/ps_cache/gpu/gpu_ps_cache.cc")
|
||||
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/ps_cache/ascend/ascend_ps_cache.cc")
|
||||
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc")
|
||||
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/server/kernel/apply_momentum_kernel.cc")
|
||||
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_fusion.cc")
|
||||
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/backend/optimizer/gpu/post_batch_norm_add_relu_fusion.cc")
|
||||
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_grad_fusion.cc")
|
||||
|
|
Loading…
Reference in New Issue