!16644 Add server round kernel for hybrid.

From: @zpac
Reviewed-by: @limingqi107, @cristoval
Signed-off-by: @cristoval
Committed by mindspore-ci-bot on 2021-06-03 19:36:45 +08:00 via Gitee
commit cdd032c237
16 changed files with 546 additions and 29 deletions

@ -463,7 +463,14 @@ bool OptInlineAction(const ResourcePtr &res) {
 bool GeOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kGePasses); }
-bool VmOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kVmPasses); }
+bool VmOptimizeAction(const ResourcePtr &res) {
+#if (ENABLE_CPU && !_WIN32)
+  if (ps::PSContext::instance()->is_ps_mode()) {
+    kVmPasses.push_back({"server_communication_op_fusion", ps::Util::FuseServerCommOps});
+  }
+#endif
+  return OptimizeAction(res, kVmPasses);
+}
 bool PynativeOptimizeAction(const ResourcePtr &resource) {
   WITH(MsProfile::GetProfile())[&resource]() { (void)OptimizeAction(resource, kPynativePasses); };
@ -613,7 +620,7 @@ bool ExecuteAction(const ResourcePtr &res) {
 }
 #if (ENABLE_CPU && !_WIN32)
-bool StartPSWorkerAction(const ResourcePtr &) {
+bool StartPSWorkerAction(const ResourcePtr &res) {
   ps::Worker::GetInstance().Run();
   return true;
 }
@ -632,7 +639,8 @@ bool StartPSServerAction(const ResourcePtr &res) {
 bool StartServerAction(const ResourcePtr &res) {
   FuncGraphPtr func_graph = res->func_graph();
   const std::string &server_mode_ = ps::PSContext::instance()->server_mode();
-  size_t worker_num = ps::PSContext::instance()->initial_worker_num();
+  uint32_t worker_num = ps::PSContext::instance()->initial_worker_num();
+  uint32_t server_num = ps::PSContext::instance()->initial_server_num();
   uint64_t fl_server_port = ps::PSContext::instance()->fl_server_port();
   // Update model threshold is a certain ratio of start_fl_job threshold.
@ -646,7 +654,9 @@ bool StartServerAction(const ResourcePtr &res) {
   std::vector<ps::server::RoundConfig> rounds_config = {
     {"startFLJob", true, start_fl_job_time_window, true, start_fl_job_threshold},
     {"updateModel", true, update_model_time_window, true, update_model_threshold},
-    {"getModel"}};
+    {"getModel"},
+    {"pullWeight"},
+    {"pushWeight", false, 3000, true, server_num}};
   size_t executor_threshold = 0;
   if (server_mode_ == ps::kServerModeFL || server_mode_ == ps::kServerModeHybrid) {
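The two new rounds lean on RoundConfig's aggregate initialization, so {"pullWeight"} keeps every default while {"pushWeight", false, 3000, true, server_num} disables the timeout check and instead counts pushes until server_num counts have been recorded. A minimal sketch of the field order these initializers imply; the field names and default values below are assumptions inferred from this hunk, not taken from the actual round definition:

// Hypothetical shape of ps::server::RoundConfig, inferred from the aggregate initializers above.
struct RoundConfig {
  std::string name;            // "startFLJob", "pullWeight", "pushWeight", ...
  bool check_timeout = true;   // "pushWeight" passes false here
  size_t time_window = 3000;   // milliseconds
  bool check_count = false;    // "pushWeight" passes true here
  size_t threshold_count = 0;  // "pushWeight" uses server_num as its threshold
};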

@ -99,7 +99,7 @@ std::string PSContext::ms_role() const {
 bool PSContext::is_worker() const {
   if (server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) {
-    return role_ == kRoleOfWorker;
+    return role_ == kEnvRoleOfWorker;
   }
   return is_worker_;
 }
@ -185,7 +185,7 @@ const std::string &PSContext::server_mode() const { return server_mode_; }
 void PSContext::set_ms_role(const std::string &role) {
   if (server_mode_ != kServerModeFL && server_mode_ != kServerModeHybrid) {
-    MS_LOG(EXCEPTION) << "Only federated learning supports to set role by ps context.";
+    MS_LOG(EXCEPTION) << "Only federated learning supports to set role by fl context.";
     return;
   }
   if (role != kEnvRoleOfWorker && role != kEnvRoleOfServer && role != kEnvRoleOfScheduler) {
@ -198,7 +198,7 @@ void PSContext::set_ms_role(const std::string &role) {
 void PSContext::set_worker_num(uint32_t worker_num) {
   // Hybrid training mode only supports one worker for now.
   if (server_mode_ == kServerModeHybrid && worker_num != 1) {
-    MS_LOG(EXCEPTION) << "The worker number should be set to 1 in hybrid training mode.";
+    MS_LOG(EXCEPTION) << "The worker number should be set to 1 for now in hybrid training mode.";
     return;
   }
   worker_num_ = worker_num;

@ -43,11 +43,11 @@ constexpr char kEnvRoleOfNotPS[] = "MS_NOT_PS";
 // 2: Server is in mixed training mode.
 // 3: Server enables secure aggregation.
 // For example: 1010 stands for that the server is in federated learning mode and secure aggregation is enabled.
-enum class ResetterRound { kNoNeedToReset, kUpdateModel, kReconstructSeccrets, kWorkerUploadWeights };
+enum class ResetterRound { kNoNeedToReset, kUpdateModel, kReconstructSeccrets, kPushWeight };
 const std::map<uint32_t, ResetterRound> kServerContextToResetRoundMap = {{0b0010, ResetterRound::kUpdateModel},
                                                                          {0b1010, ResetterRound::kReconstructSeccrets},
-                                                                         {0b1100, ResetterRound::kWorkerUploadWeights},
-                                                                         {0b0100, ResetterRound::kWorkerUploadWeights},
+                                                                         {0b1100, ResetterRound::kPushWeight},
+                                                                         {0b0100, ResetterRound::kPushWeight},
                                                                          {0b0100, ResetterRound::kUpdateModel}};
 class PSContext {
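The renamed kPushWeight value is what PushWeightKernel::OnLastCountEvent (added below) checks before finishing an iteration. Note that 0b0100 now appears twice as a key; since a std::map keeps the first entry inserted for a duplicate key, the trailing {0b0100, kUpdateModel} pair has no effect. A minimal lookup sketch; server_context and the surrounding code are hypothetical, only the enum and map come from this header:

// Hypothetical: resolve the resetter round for a server in mixed (hybrid) training mode
// without secure aggregation, i.e. context bits 0b0100 per the comment above.
uint32_t server_context = 0b0100;
ResetterRound resetter_round = ResetterRound::kNoNeedToReset;
auto iter = kServerContextToResetRoundMap.find(server_context);
if (iter != kServerContextToResetRoundMap.end()) {
  resetter_round = iter->second;  // kPushWeight: the pushWeight round ends the iteration
}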

@ -131,12 +131,22 @@ const OptimParamNameToIndex kSparseAdamNameToIdx = {{"inputs",
                                                      {"outputs", {}}};
 const OptimParamNameToIndex kSparseFtrlNameToIdx = {
   {"inputs", {{kWeight, 0}, {kAccumulation, 1}, {kFtrlLinear, 2}, {kGradient, 3}, {kIndices, 4}}}, {"outputs", {}}};
-const std::map<std::string, OptimParamNameToIndex> kNameToIdxMap = {
-  {kApplyMomentumOpName, kMomentumNameToIdx},
-  {kFusedSparseAdamName, kSparseAdamNameToIdx},
-  {kSparseApplyFtrlOpName, kSparseFtrlNameToIdx},
-  {kApplyAdamOpName, kAdamNameToIdx},
-};
+const OptimParamNameToIndex kAdamWeightDecayNameToIdx = {{"inputs",
+                                                          {{"weight", 0},
+                                                           {"m", 1},
+                                                           {"v", 2},
+                                                           {"lr", 3},
+                                                           {"beta1", 4},
+                                                           {"beta2", 5},
+                                                           {"eps", 6},
+                                                           {"weight_decay", 7},
+                                                           {"grad", 8}}},
+                                                         {"outputs", {}}};
+const std::map<std::string, OptimParamNameToIndex> kNameToIdxMap = {{kApplyMomentumOpName, kMomentumNameToIdx},
+                                                                    {kFusedSparseAdamName, kSparseAdamNameToIdx},
+                                                                    {kSparseApplyFtrlOpName, kSparseFtrlNameToIdx},
+                                                                    {kApplyAdamOpName, kAdamNameToIdx},
+                                                                    {"AdamWeightDecay", kAdamWeightDecayNameToIdx}};
 constexpr uint32_t kLeaderServerRank = 0;
 constexpr size_t kWorkerMgrThreadPoolSize = 32;
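The new table lets the server resolve AdamWeightDecay optimizer inputs by name rather than by hard-coded offsets. A small hypothetical lookup, assuming OptimParamNameToIndex is the nested name-to-index map the initializers above suggest:

// Hypothetical usage: find which inputs of an AdamWeightDecay CNode carry weight_decay and grad.
const OptimParamNameToIndex &adam_wd_index = kNameToIdxMap.at("AdamWeightDecay");
size_t weight_decay_input = adam_wd_index.at("inputs").at("weight_decay");  // 7, per the table above
size_t grad_input = adam_wd_index.at("inputs").at("grad");                  // 8, per the table above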

@ -45,7 +45,7 @@ void DistributedCountService::RegisterCounter(const std::string &name, size_t gl
     return;
   }
   if (global_threshold_count_.count(name) != 0) {
-    MS_LOG(ERROR) << "Counter for " << name << " is already set.";
+    MS_LOG(WARNING) << "Counter for " << name << " is already set.";
     return;
   }

@ -135,7 +135,7 @@ bool Executor::HandleModelUpdateAsync(const std::map<std::string, UploadData> &f
   return true;
 }
-bool Executor::HandleOverwriteWeightsByKey(const std::map<std::string, Address> &feature_map) {
+bool Executor::HandlePushWeight(const std::map<std::string, Address> &feature_map) {
   for (const auto &trainable_param : feature_map) {
     const std::string &param_name = trainable_param.first;
     if (param_aggrs_.count(param_name) == 0) {
@ -193,7 +193,7 @@ AddressPtr Executor::HandlePull(const std::string &param_name) {
   return addr;
 }
-std::map<std::string, AddressPtr> Executor::HandleGetWeightsByKey(const std::vector<std::string> &param_names) {
+std::map<std::string, AddressPtr> Executor::HandlePullWeight(const std::vector<std::string> &param_names) {
   std::map<std::string, AddressPtr> weights;
   for (const auto &param_name : param_names) {
     if (param_aggrs_.count(param_name) == 0) {

@ -63,11 +63,11 @@ class Executor {
   // asynchronously.
   bool HandleModelUpdateAsync(const std::map<std::string, UploadData> &feature_map);
-  // Forcibly overwrite specific weights in overwriteWeights message.
-  bool HandleOverwriteWeightsByKey(const std::map<std::string, Address> &feature_map);
+  // Overwrite the weights in server using pushed feature map.
+  bool HandlePushWeight(const std::map<std::string, Address> &feature_map);
-  // Returns value for multiple trainable parameters passed by weight_names.
-  std::map<std::string, AddressPtr> HandleGetWeightsByKey(const std::vector<std::string> &param_names);
+  // Returns multiple trainable parameters passed by weight_names.
+  std::map<std::string, AddressPtr> HandlePullWeight(const std::vector<std::string> &param_names);
   // Reset the aggregation status for all aggregation kernels in the server.
   void ResetAggregationStatus();
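A hypothetical round-kernel-side sketch of how the two renamed handlers fit together. Only the two signatures above come from this header; the parameter names and the Address aggregate layout are assumptions:

// Pull the current values of some trainable parameters, then overwrite them with pushed data.
std::vector<std::string> weight_names = {"fc1.weight", "fc1.bias"};  // hypothetical parameter names
std::map<std::string, AddressPtr> pulled = Executor::GetInstance().HandlePullWeight(weight_names);
std::map<std::string, Address> pushed;
for (const auto &item : pulled) {
  pushed[item.first] = {item.second->addr, item.second->size};  // assumed Address{addr, size} layout
}
if (!Executor::GetInstance().HandlePushWeight(pushed)) {
  MS_LOG(ERROR) << "Overwriting weights on the server failed.";
}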

@ -95,6 +95,7 @@ class FedAvgKernel : public AggregationKernel {
       return;
     };
     DistributedCountService::GetInstance().RegisterCounter(name_, done_count_, {first_cnt_handler, last_cnt_handler});
+    GenerateReuseKernelNodeInfo();
     return;
   }
@ -124,7 +125,6 @@ class FedAvgKernel : public AggregationKernel {
     participated_ = true;
     DistributedCountService::GetInstance().Count(
       name_, std::to_string(DistributedCountService::GetInstance().local_rank()) + "_" + std::to_string(accum_count_));
-    GenerateReuseKernelNodeInfo();
     return true;
   }

@ -0,0 +1,132 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ps/server/kernel/round/pull_weight_kernel.h"
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "ps/server/model_store.h"
namespace mindspore {
namespace ps {
namespace server {
namespace kernel {
uint64_t PullWeightKernel::retry_count_ = 0;
void PullWeightKernel::InitKernel(size_t) {
executor_ = &Executor::GetInstance();
MS_EXCEPTION_IF_NULL(executor_);
if (!executor_->initialized()) {
MS_LOG(EXCEPTION) << "Executor must be initialized in server pipeline.";
return;
}
}
bool PullWeightKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
MS_LOG(DEBUG) << "Launching PullWeightKernel kernel.";
void *req_data = inputs[0]->addr;
std::shared_ptr<FBBuilder> fbb = std::make_shared<FBBuilder>();
if (fbb == nullptr || req_data == nullptr) {
MS_LOG(ERROR) << "FBBuilder builder or req_data is nullptr.";
return false;
}
const schema::RequestPullWeight *pull_weight_req = flatbuffers::GetRoot<schema::RequestPullWeight>(req_data);
if (pull_weight_req == nullptr) {
std::string reason = "Building flatbuffers schema failed for RequestPullWeight";
BuildPullWeightRsp(fbb, schema::ResponseCode_RequestError, reason, {});
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return false;
}
PullWeight(fbb, pull_weight_req);
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
bool PullWeightKernel::Reset() { return true; }
void PullWeightKernel::PullWeight(std::shared_ptr<FBBuilder> fbb, const schema::RequestPullWeight *pull_weight_req) {
std::map<std::string, AddressPtr> feature_maps = {};
size_t current_iter = LocalMetaStore::GetInstance().curr_iter_num();
size_t pull_weight_iter = static_cast<size_t>(pull_weight_req->iteration());
// The PullWeight round should be in the same iteration as other rounds.
if (pull_weight_iter != current_iter) {
std::string reason = "PullWeight iteration " + std::to_string(pull_weight_iter) +
" is invalid. Server current iteration: " + std::to_string(current_iter);
BuildPullWeightRsp(fbb, schema::ResponseCode_RequestError, reason, feature_maps);
MS_LOG(WARNING) << reason;
return;
}
std::vector<std::string> weight_names = {};
auto weights_names_fbs = pull_weight_req->weight_names();
for (size_t i = 0; i < weights_names_fbs->size(); i++) {
weight_names.push_back(weights_names_fbs->Get(i)->str());
}
if (!executor_->IsWeightAggrDone(weight_names)) {
retry_count_++;
std::string reason = "The aggregation for the weights is not done yet.";
BuildPullWeightRsp(fbb, schema::ResponseCode_SucNotReady, reason, feature_maps);
if (retry_count_ % 10 == 0) {
MS_LOG(WARNING) << reason << " Retry count is " << retry_count_;
}
return;
}
feature_maps = executor_->HandlePullWeight(weight_names);
if (feature_maps.empty()) {
std::string reason = "The feature_map is empty for the given weight names.";
BuildPullWeightRsp(fbb, schema::ResponseCode_RequestError, reason, feature_maps);
MS_LOG(WARNING) << reason;
return;
}
BuildPullWeightRsp(fbb, schema::ResponseCode_SUCCEED,
"Pull weights by weight names for iteration " + std::to_string(pull_weight_iter) + " success.",
feature_maps);
return;
}
void PullWeightKernel::BuildPullWeightRsp(std::shared_ptr<FBBuilder> fbb, const schema::ResponseCode retcode,
const std::string &reason,
const std::map<std::string, AddressPtr> &feature_maps) {
auto fbs_reason = fbb->CreateString(reason);
std::vector<flatbuffers::Offset<schema::FeatureMap>> fbs_feature_maps;
for (auto feature_map : feature_maps) {
auto fbs_weight_fullname = fbb->CreateString(feature_map.first);
auto fbs_weight_data =
fbb->CreateVector(reinterpret_cast<float *>(feature_map.second->addr), feature_map.second->size / sizeof(float));
auto fbs_feature_map = schema::CreateFeatureMap(*(fbb.get()), fbs_weight_fullname, fbs_weight_data);
fbs_feature_maps.push_back(fbs_feature_map);
}
auto fbs_feature_maps_vector = fbb->CreateVector(fbs_feature_maps);
schema::ResponsePullWeightBuilder rsp_pull_weight_builder(*(fbb.get()));
rsp_pull_weight_builder.add_retcode(retcode);
rsp_pull_weight_builder.add_reason(fbs_reason);
rsp_pull_weight_builder.add_feature_map(fbs_feature_maps_vector);
auto rsp_pull_weight = rsp_pull_weight_builder.Finish();
fbb->Finish(rsp_pull_weight);
return;
}
REG_ROUND_KERNEL(pullWeight, PullWeightKernel)
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace mindspore
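For context, a hedged sketch of how the response assembled by BuildPullWeightRsp might be consumed on the receiving side. The getters (retcode, reason, feature_map, weight_fullname, data) are assumed to be the flatbuffers accessors generated for the fields the builder sets above, and rsp_data is a hypothetical buffer holding the reply:

// Read the weights returned by the pullWeight round (a sketch, not code from this commit).
const schema::ResponsePullWeight *rsp = flatbuffers::GetRoot<schema::ResponsePullWeight>(rsp_data);
if (rsp->retcode() != schema::ResponseCode_SUCCEED) {
  MS_LOG(WARNING) << "PullWeight was rejected: " << rsp->reason()->str();
} else {
  auto fbs_feature_maps = rsp->feature_map();
  for (size_t i = 0; i < fbs_feature_maps->size(); i++) {
    const std::string weight_name = fbs_feature_maps->Get(i)->weight_fullname()->str();
    const float *weight_data = fbs_feature_maps->Get(i)->data()->data();
    size_t weight_size = fbs_feature_maps->Get(i)->data()->size() * sizeof(float);
    // Copy weight_size bytes from weight_data into the local parameter named weight_name.
  }
}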

@ -0,0 +1,57 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_PULL_WEIGHT_KERNEL_H_
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_PULL_WEIGHT_KERNEL_H_
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "ps/server/common.h"
#include "ps/server/kernel/round/round_kernel.h"
#include "ps/server/kernel/round/round_kernel_factory.h"
#include "ps/server/executor.h"
namespace mindspore {
namespace ps {
namespace server {
namespace kernel {
class PullWeightKernel : public RoundKernel {
public:
PullWeightKernel() = default;
~PullWeightKernel() override = default;
void InitKernel(size_t required_cnt) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs);
bool Reset() override;
private:
void PullWeight(std::shared_ptr<FBBuilder> fbb, const schema::RequestPullWeight *pull_weight_req);
void BuildPullWeightRsp(std::shared_ptr<FBBuilder> fbb, const schema::ResponseCode retcode, const std::string &reason,
const std::map<std::string, AddressPtr> &feature_maps);
Executor *executor_;
// The count of retrying because the aggregation of the weights is not done.
static uint64_t retry_count_;
};
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_PULL_WEIGHT_KERNEL_H_

@ -0,0 +1,131 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ps/server/kernel/round/push_weight_kernel.h"
namespace mindspore {
namespace ps {
namespace server {
namespace kernel {
void PushWeightKernel::InitKernel(size_t) {
executor_ = &Executor::GetInstance();
MS_EXCEPTION_IF_NULL(executor_);
if (!executor_->initialized()) {
MS_LOG(EXCEPTION) << "Executor must be initialized in server pipeline.";
return;
}
local_rank_ = DistributedCountService::GetInstance().local_rank();
}
bool PushWeightKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
MS_LOG(INFO) << "Launching PushWeightKernel kernel.";
void *req_data = inputs[0]->addr;
std::shared_ptr<FBBuilder> fbb = std::make_shared<FBBuilder>();
if (fbb == nullptr || req_data == nullptr) {
MS_LOG(ERROR) << "FBBuilder builder or req_data is nullptr.";
return false;
}
const schema::RequestPushWeight *push_weight_req = flatbuffers::GetRoot<schema::RequestPushWeight>(req_data);
if (push_weight_req == nullptr) {
std::string reason = "Building flatbuffers schema failed for RequestPushWeight";
BuildPushWeightRsp(fbb, schema::ResponseCode_RequestError, reason);
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return false;
}
PushWeight(fbb, push_weight_req);
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
bool PushWeightKernel::Reset() {
MS_LOG(INFO) << "PushWeightKernel reset!";
StopTimer();
DistributedCountService::GetInstance().ResetCounter(name_);
return true;
}
void PushWeightKernel::OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &) {
if (PSContext::instance()->resetter_round() == ResetterRound::kPushWeight) {
FinishIteration();
}
return;
}
void PushWeightKernel::PushWeight(std::shared_ptr<FBBuilder> fbb, const schema::RequestPushWeight *push_weight_req) {
if (fbb == nullptr || push_weight_req == nullptr) {
return;
}
size_t iteration = static_cast<size_t>(push_weight_req->iteration());
if (iteration != LocalMetaStore::GetInstance().curr_iter_num()) {
std::string reason = "PushWeight iteration number is invalid:" + std::to_string(iteration) +
", current iteration:" + std::to_string(LocalMetaStore::GetInstance().curr_iter_num());
BuildPushWeightRsp(fbb, schema::ResponseCode_OutOfTime, reason);
MS_LOG(ERROR) << reason;
return;
}
std::map<std::string, Address> upload_feature_map = ParseFeatureMap(push_weight_req);
if (upload_feature_map.empty()) {
std::string reason = "PushWeight overwrite feature_map is empty.";
BuildPushWeightRsp(fbb, schema::ResponseCode_RequestError, reason);
MS_LOG(ERROR) << reason;
return;
}
if (!executor_->HandlePushWeight(upload_feature_map)) {
std::string reason = "OverwriteWeights failed.";
BuildPushWeightRsp(fbb, schema::ResponseCode_SystemError, reason);
MS_LOG(ERROR) << reason;
return;
}
DistributedCountService::GetInstance().Count(name_, std::to_string(local_rank_));
BuildPushWeightRsp(fbb, schema::ResponseCode_SUCCEED, "PushWeight succeed.");
return;
}
std::map<std::string, Address> PushWeightKernel::ParseFeatureMap(const schema::RequestPushWeight *push_weight_req) {
RETURN_IF_NULL(push_weight_req, {});
std::map<std::string, Address> upload_feature_map;
auto fbs_feature_map = push_weight_req->feature_map();
for (size_t i = 0; i < fbs_feature_map->size(); i++) {
std::string weight_full_name = fbs_feature_map->Get(i)->weight_fullname()->str();
float *weight_data = const_cast<float *>(fbs_feature_map->Get(i)->data()->data());
size_t weight_size = fbs_feature_map->Get(i)->data()->size() * sizeof(float);
upload_feature_map[weight_full_name] = {weight_data, weight_size};
}
return upload_feature_map;
}
void PushWeightKernel::BuildPushWeightRsp(std::shared_ptr<FBBuilder> fbb, const schema::ResponseCode retcode,
const std::string &reason) {
auto fbs_reason = fbb->CreateString(reason);
schema::ResponsePushWeightBuilder rsp_push_weight_builder(*(fbb.get()));
rsp_push_weight_builder.add_retcode(retcode);
rsp_push_weight_builder.add_reason(fbs_reason);
auto rsp_push_weight = rsp_push_weight_builder.Finish();
fbb->Finish(rsp_push_weight);
return;
}
REG_ROUND_KERNEL(pushWeight, PushWeightKernel)
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace mindspore
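A hedged sketch of the request this kernel parses, built the same way BuildPushWeightRsp builds its reply. RequestPushWeightBuilder with add_iteration and add_feature_map is assumed to be the generated API mirroring the fields PushWeight and ParseFeatureMap read; local_weights and current_iteration are hypothetical:

// Assemble a RequestPushWeight on the pushing side (sketch under the assumptions above).
std::shared_ptr<FBBuilder> fbb = std::make_shared<FBBuilder>();
std::vector<flatbuffers::Offset<schema::FeatureMap>> fbs_feature_maps;
for (const auto &item : local_weights) {  // hypothetical std::map<std::string, std::vector<float>>
  auto fbs_name = fbb->CreateString(item.first);
  auto fbs_data = fbb->CreateVector(item.second.data(), item.second.size());
  fbs_feature_maps.push_back(schema::CreateFeatureMap(*(fbb.get()), fbs_name, fbs_data));
}
auto fbs_feature_maps_vector = fbb->CreateVector(fbs_feature_maps);
schema::RequestPushWeightBuilder req_push_weight_builder(*(fbb.get()));
req_push_weight_builder.add_iteration(current_iteration);  // must equal the server's curr_iter_num()
req_push_weight_builder.add_feature_map(fbs_feature_maps_vector);
fbb->Finish(req_push_weight_builder.Finish());
// fbb->GetBufferPointer() and fbb->GetSize() then form the payload the pushWeight round receives.

With the pushWeight round configured to count up to server_num (see the rounds_config change above), OnLastCountEvent fires once every server has recorded a push and, in hybrid mode, finishes the iteration.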

@ -0,0 +1,57 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_PUSH_WEIGHT_KERNEL_H_
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_PUSH_WEIGHT_KERNEL_H_
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "ps/server/common.h"
#include "ps/server/kernel/round/round_kernel.h"
#include "ps/server/kernel/round/round_kernel_factory.h"
#include "ps/server/executor.h"
namespace mindspore {
namespace ps {
namespace server {
namespace kernel {
class PushWeightKernel : public RoundKernel {
public:
PushWeightKernel() = default;
~PushWeightKernel() override = default;
void InitKernel(size_t threshold_count) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs);
bool Reset() override;
void OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message) override;
private:
void PushWeight(std::shared_ptr<FBBuilder> fbb, const schema::RequestPushWeight *push_weight_req);
std::map<std::string, Address> ParseFeatureMap(const schema::RequestPushWeight *push_weight_req);
void BuildPushWeightRsp(std::shared_ptr<FBBuilder> fbb, const schema::ResponseCode retcode,
const std::string &reason);
Executor *executor_;
uint32_t local_rank_;
};
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_PUSH_WEIGHT_KERNEL_H_

@ -36,8 +36,14 @@ bool ParameterAggregator::Init(const CNodePtr &cnode, size_t threshold_count) {
   required_pull_count_ = threshold_count;
   MS_LOG(DEBUG) << "Start initializing kernels for " << AnfAlgo::GetCNodeName(cnode);
-  InitAggregationKernels(cnode);
-  InitOptimizerKernels(cnode);
+  if (!InitAggregationKernels(cnode)) {
+    MS_LOG(EXCEPTION) << "Initializing aggregation kernels failed.";
+    return false;
+  }
+  if (!InitOptimizerKernels(cnode)) {
+    MS_LOG(EXCEPTION) << "Initializing optimizer kernels failed.";
+    return false;
+  }
   return true;
 }
@ -183,9 +189,10 @@ bool ParameterAggregator::InitAggregationKernels(const CNodePtr &cnode) {
 }
 bool ParameterAggregator::InitOptimizerKernels(const CNodePtr &cnode) {
-  if (PSContext::instance()->server_mode() == kServerModeFL) {
+  if (PSContext::instance()->server_mode() == kServerModeFL ||
+      PSContext::instance()->server_mode() == kServerModeHybrid) {
     MS_LOG(DEBUG) << "Federated learning mode doesn't need optimizer kernel.";
-    return false;
+    return true;
   }
   MS_EXCEPTION_IF_NULL(cnode);
   const std::string &name = AnfAlgo::GetCNodeName(cnode);

@ -17,6 +17,7 @@
 #include "ps/util.h"
 #include <unordered_map>
 #include <vector>
+#include <memory>
 #include "ps/constants.h"
 #include "ps/ps_context.h"
 #include "utils/ms_utils.h"
@ -128,5 +129,107 @@ void Util::ReduceSparseGradient(float *gradients, int *indices, const size_t ind
mindspore::kernel::SparseOptimizerCPUKernel::BucketReduceSparseGradient(param);
}
bool Util::FuseServerCommOps(const pipeline::ResourcePtr &res) {
FuncGraphPtr func_graph = res->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
DoFusion(func_graph, kPullWeightOpName, kFusedPullWeightOpName);
DoFusion(func_graph, kPushWeightOpName, kFusedPushWeightOpName);
return true;
}
void Util::DoFusion(FuncGraphPtr func_graph, const std::string &cnode_name, const std::string &fused_cnode_name) {
MS_EXCEPTION_IF_NULL(func_graph);
std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return());
std::vector<AnfNodePtr> single_nodes;
std::vector<std::string> weight_names;
std::vector<int64_t> indices;
for (const AnfNodePtr &node : node_list) {
if (node != nullptr && node->isa<CNode>()) {
if (AnfAlgo::GetCNodeName(node) == cnode_name) {
single_nodes.push_back(node);
auto weight_name_value_node =
AnfAlgo::GetInputNode(node->cast<CNodePtr>(), kNodeInputWeightNameOffset)->cast<ValueNodePtr>();
const std::string &weight_name = GetValue<std::string>(weight_name_value_node->value());
weight_names.push_back(weight_name);
auto weight_index_value_node =
AnfAlgo::GetInputNode(node->cast<CNodePtr>(), kNodeInputWeightIndexOffset)->cast<ValueNodePtr>();
int64_t weight_index = GetValue<int64_t>(weight_index_value_node->value());
indices.push_back(weight_index);
}
}
}
auto prim = std::make_shared<Primitive>(fused_cnode_name);
MS_EXCEPTION_IF_NULL(prim);
std::vector<AnfNodePtr> fused_node_inputs = {};
fused_node_inputs.push_back(NewValueNode(prim));
std::for_each(single_nodes.begin(), single_nodes.end(), [&](AnfNodePtr node) {
fused_node_inputs.push_back(AnfAlgo::GetInputNode(node->cast<CNodePtr>(), 0));
});
auto fused_cnode = func_graph->NewCNode(fused_node_inputs);
MS_EXCEPTION_IF_NULL(fused_cnode);
AnfAlgo::SetNodeAttr(kAttrPsKey, MakeValue(weight_names), fused_cnode);
AnfAlgo::SetNodeAttr(kAttrIndex, MakeValue(indices), fused_cnode);
AnfAlgo::SetNodeAttr(kAttrPrimitiveTarget, MakeValue(kCPUDevice), fused_cnode);
auto kernel_info = std::make_shared<device::KernelInfo>();
MS_EXCEPTION_IF_NULL(kernel_info);
fused_cnode->set_kernel_info(kernel_info);
auto kernel_build_info = GenerateKernelBuildInfo(single_nodes);
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, fused_cnode.get());
AbstractBasePtrList abstract_list;
for (const auto &node : single_nodes) {
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
abstract_list.push_back(cnode->abstract());
}
auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
MS_EXCEPTION_IF_NULL(abstract_tuple);
fused_cnode->set_abstract(abstract_tuple);
auto manager = func_graph->manager();
MS_EXCEPTION_IF_NULL(manager);
for (const auto &node : single_nodes) {
if (!manager->Replace(node, fused_cnode)) {
MS_LOG(EXCEPTION) << "manager replace node failed";
}
}
return;
}
kernel::KernelBuildInfoPtr Util::GenerateKernelBuildInfo(const std::vector<AnfNodePtr> &node_list) {
std::vector<std::string> inputs_device_format;
std::vector<std::string> outputs_device_format;
std::vector<TypeId> inputs_device_type;
std::vector<TypeId> outputs_device_type;
std::vector<std::vector<size_t>> outputs_shape;
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
for (size_t idx = 0; idx < node_list.size(); ++idx) {
auto cnode = utils::cast<CNodePtr>(node_list[idx]);
MS_EXCEPTION_IF_NULL(cnode);
size_t input_num = AnfAlgo::GetInputTensorNum(cnode);
for (size_t input_index = 0; input_index < input_num; ++input_index) {
inputs_device_format.push_back(kOpFormat_DEFAULT);
inputs_device_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index));
}
size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
for (size_t output_index = 0; output_index < output_num; ++output_index) {
outputs_device_format.push_back(kOpFormat_DEFAULT);
outputs_device_type.push_back(AnfAlgo::GetOutputInferDataType(cnode, output_index));
outputs_shape.push_back(AnfAlgo::GetOutputInferShape(cnode, output_index));
}
}
builder.SetInputsFormat(inputs_device_format);
builder.SetOutputsFormat(outputs_device_format);
builder.SetInputsDeviceType(inputs_device_type);
builder.SetOutputsDeviceType(outputs_device_type);
return builder.Build();
}
} // namespace ps
} // namespace mindspore

@ -18,8 +18,10 @@
 #define MINDSPORE_CCSRC_PS_UTIL_H_
 #include <map>
 #include <vector>
+#include <string>
 #include <unordered_map>
+#include "frontend/optimizer/optimizer.h"
 #include "backend/session/anf_runtime_algorithm.h"
 #include "backend/kernel_compiler/common_utils.h"
 #include "backend/kernel_compiler/cpu/sparse_optimizer_cpu_kernel.h"
@ -35,6 +37,9 @@ struct ParamInitInfo {
   float init_val_{0};
 };
+constexpr size_t kNodeInputWeightNameOffset = 1;
+constexpr size_t kNodeInputWeightIndexOffset = 2;
 class Util {
  public:
   static bool IsRoleOfPServer();
@ -48,8 +53,12 @@ class Util {
   static void ReduceSparseGradient(float *gradients, int *indices, const size_t indices_size, size_t segment_size,
                                    const size_t first_dim_size, const size_t outer_dim_size,
                                    mindspore::kernel::SparseGradient<int> *unique_sparse_grad);
+  static bool FuseServerCommOps(const pipeline::ResourcePtr &res);
  private:
+  static void DoFusion(FuncGraphPtr func_graph, const std::string &cnode_name, const std::string &fused_cnode_name);
+  static kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const std::vector<AnfNodePtr> &node_list);
   static std::unordered_map<std::string, int64_t> optimizer_to_ids;
   static std::unordered_map<int64_t, std::string> id_to_optimizers;
   static std::unordered_map<int64_t, std::string> id_to_optimizer_nodes;

@ -57,7 +57,8 @@ _set_ps_context_func_map = {
 }
 _get_ps_context_func_map = {
-    "enable_ps": ps_context().is_ps_mode
+    "enable_ps": ps_context().is_ps_mode,
+    "ms_role": ps_context().ms_role
 }