forked from mindspore-Ecosystem/mindspore
!16644 Add server round kernel for hybrid.
From: @zpac Reviewed-by: @limingqi107,@cristoval Signed-off-by: @cristoval
This commit is contained in:
commit
cdd032c237
|
@ -463,7 +463,14 @@ bool OptInlineAction(const ResourcePtr &res) {
|
|||
|
||||
bool GeOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kGePasses); }
|
||||
|
||||
bool VmOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kVmPasses); }
|
||||
bool VmOptimizeAction(const ResourcePtr &res) {
|
||||
#if (ENABLE_CPU && !_WIN32)
|
||||
if (ps::PSContext::instance()->is_ps_mode()) {
|
||||
kVmPasses.push_back({"server_communication_op_fusion", ps::Util::FuseServerCommOps});
|
||||
}
|
||||
#endif
|
||||
return OptimizeAction(res, kVmPasses);
|
||||
}
|
||||
|
||||
bool PynativeOptimizeAction(const ResourcePtr &resource) {
|
||||
WITH(MsProfile::GetProfile())[&resource]() { (void)OptimizeAction(resource, kPynativePasses); };
|
||||
|
@ -613,7 +620,7 @@ bool ExecuteAction(const ResourcePtr &res) {
|
|||
}
|
||||
|
||||
#if (ENABLE_CPU && !_WIN32)
|
||||
bool StartPSWorkerAction(const ResourcePtr &) {
|
||||
bool StartPSWorkerAction(const ResourcePtr &res) {
|
||||
ps::Worker::GetInstance().Run();
|
||||
return true;
|
||||
}
|
||||
|
@ -632,7 +639,8 @@ bool StartPSServerAction(const ResourcePtr &res) {
|
|||
bool StartServerAction(const ResourcePtr &res) {
|
||||
FuncGraphPtr func_graph = res->func_graph();
|
||||
const std::string &server_mode_ = ps::PSContext::instance()->server_mode();
|
||||
size_t worker_num = ps::PSContext::instance()->initial_worker_num();
|
||||
uint32_t worker_num = ps::PSContext::instance()->initial_worker_num();
|
||||
uint32_t server_num = ps::PSContext::instance()->initial_server_num();
|
||||
uint64_t fl_server_port = ps::PSContext::instance()->fl_server_port();
|
||||
|
||||
// Update model threshold is a certain ratio of start_fl_job threshold.
|
||||
|
@ -646,7 +654,9 @@ bool StartServerAction(const ResourcePtr &res) {
|
|||
std::vector<ps::server::RoundConfig> rounds_config = {
|
||||
{"startFLJob", true, start_fl_job_time_window, true, start_fl_job_threshold},
|
||||
{"updateModel", true, update_model_time_window, true, update_model_threshold},
|
||||
{"getModel"}};
|
||||
{"getModel"},
|
||||
{"pullWeight"},
|
||||
{"pushWeight", false, 3000, true, server_num}};
|
||||
|
||||
size_t executor_threshold = 0;
|
||||
if (server_mode_ == ps::kServerModeFL || server_mode_ == ps::kServerModeHybrid) {
|
||||
|
|
|
@ -99,7 +99,7 @@ std::string PSContext::ms_role() const {
|
|||
|
||||
bool PSContext::is_worker() const {
|
||||
if (server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) {
|
||||
return role_ == kRoleOfWorker;
|
||||
return role_ == kEnvRoleOfWorker;
|
||||
}
|
||||
return is_worker_;
|
||||
}
|
||||
|
@ -185,7 +185,7 @@ const std::string &PSContext::server_mode() const { return server_mode_; }
|
|||
|
||||
void PSContext::set_ms_role(const std::string &role) {
|
||||
if (server_mode_ != kServerModeFL && server_mode_ != kServerModeHybrid) {
|
||||
MS_LOG(EXCEPTION) << "Only federated learning supports to set role by ps context.";
|
||||
MS_LOG(EXCEPTION) << "Only federated learning supports to set role by fl context.";
|
||||
return;
|
||||
}
|
||||
if (role != kEnvRoleOfWorker && role != kEnvRoleOfServer && role != kEnvRoleOfScheduler) {
|
||||
|
@ -198,7 +198,7 @@ void PSContext::set_ms_role(const std::string &role) {
|
|||
void PSContext::set_worker_num(uint32_t worker_num) {
|
||||
// Hybrid training mode only supports one worker for now.
|
||||
if (server_mode_ == kServerModeHybrid && worker_num != 1) {
|
||||
MS_LOG(EXCEPTION) << "The worker number should be set to 1 in hybrid training mode.";
|
||||
MS_LOG(EXCEPTION) << "The worker number should be set to 1 for now in hybrid training mode.";
|
||||
return;
|
||||
}
|
||||
worker_num_ = worker_num;
|
||||
|
|
|
@ -43,11 +43,11 @@ constexpr char kEnvRoleOfNotPS[] = "MS_NOT_PS";
|
|||
// 2: Server is in mixed training mode.
|
||||
// 3: Server enables sucure aggregation.
|
||||
// For example: 1010 stands for that the server is in federated learning mode and sucure aggregation is enabled.
|
||||
enum class ResetterRound { kNoNeedToReset, kUpdateModel, kReconstructSeccrets, kWorkerUploadWeights };
|
||||
enum class ResetterRound { kNoNeedToReset, kUpdateModel, kReconstructSeccrets, kPushWeight };
|
||||
const std::map<uint32_t, ResetterRound> kServerContextToResetRoundMap = {{0b0010, ResetterRound::kUpdateModel},
|
||||
{0b1010, ResetterRound::kReconstructSeccrets},
|
||||
{0b1100, ResetterRound::kWorkerUploadWeights},
|
||||
{0b0100, ResetterRound::kWorkerUploadWeights},
|
||||
{0b1100, ResetterRound::kPushWeight},
|
||||
{0b0100, ResetterRound::kPushWeight},
|
||||
{0b0100, ResetterRound::kUpdateModel}};
|
||||
|
||||
class PSContext {
|
||||
|
|
|
@ -131,12 +131,22 @@ const OptimParamNameToIndex kSparseAdamNameToIdx = {{"inputs",
|
|||
{"outputs", {}}};
|
||||
const OptimParamNameToIndex kSparseFtrlNameToIdx = {
|
||||
{"inputs", {{kWeight, 0}, {kAccumulation, 1}, {kFtrlLinear, 2}, {kGradient, 3}, {kIndices, 4}}}, {"outputs", {}}};
|
||||
const std::map<std::string, OptimParamNameToIndex> kNameToIdxMap = {
|
||||
{kApplyMomentumOpName, kMomentumNameToIdx},
|
||||
{kFusedSparseAdamName, kSparseAdamNameToIdx},
|
||||
{kSparseApplyFtrlOpName, kSparseFtrlNameToIdx},
|
||||
{kApplyAdamOpName, kAdamNameToIdx},
|
||||
};
|
||||
const OptimParamNameToIndex kAdamWeightDecayNameToIdx = {{"inputs",
|
||||
{{"weight", 0},
|
||||
{"m", 1},
|
||||
{"v", 2},
|
||||
{"lr", 3},
|
||||
{"beta1", 4},
|
||||
{"beta2", 5},
|
||||
{"eps", 6},
|
||||
{"weight_decay", 7},
|
||||
{"grad", 8}}},
|
||||
{"outputs", {}}};
|
||||
const std::map<std::string, OptimParamNameToIndex> kNameToIdxMap = {{kApplyMomentumOpName, kMomentumNameToIdx},
|
||||
{kFusedSparseAdamName, kSparseAdamNameToIdx},
|
||||
{kSparseApplyFtrlOpName, kSparseFtrlNameToIdx},
|
||||
{kApplyAdamOpName, kAdamNameToIdx},
|
||||
{"AdamWeightDecay", kAdamWeightDecayNameToIdx}};
|
||||
|
||||
constexpr uint32_t kLeaderServerRank = 0;
|
||||
constexpr size_t kWorkerMgrThreadPoolSize = 32;
|
||||
|
|
|
@ -45,7 +45,7 @@ void DistributedCountService::RegisterCounter(const std::string &name, size_t gl
|
|||
return;
|
||||
}
|
||||
if (global_threshold_count_.count(name) != 0) {
|
||||
MS_LOG(ERROR) << "Counter for " << name << " is already set.";
|
||||
MS_LOG(WARNING) << "Counter for " << name << " is already set.";
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
|
@ -135,7 +135,7 @@ bool Executor::HandleModelUpdateAsync(const std::map<std::string, UploadData> &f
|
|||
return true;
|
||||
}
|
||||
|
||||
bool Executor::HandleOverwriteWeightsByKey(const std::map<std::string, Address> &feature_map) {
|
||||
bool Executor::HandlePushWeight(const std::map<std::string, Address> &feature_map) {
|
||||
for (const auto &trainable_param : feature_map) {
|
||||
const std::string ¶m_name = trainable_param.first;
|
||||
if (param_aggrs_.count(param_name) == 0) {
|
||||
|
@ -193,7 +193,7 @@ AddressPtr Executor::HandlePull(const std::string ¶m_name) {
|
|||
return addr;
|
||||
}
|
||||
|
||||
std::map<std::string, AddressPtr> Executor::HandleGetWeightsByKey(const std::vector<std::string> ¶m_names) {
|
||||
std::map<std::string, AddressPtr> Executor::HandlePullWeight(const std::vector<std::string> ¶m_names) {
|
||||
std::map<std::string, AddressPtr> weights;
|
||||
for (const auto ¶m_name : param_names) {
|
||||
if (param_aggrs_.count(param_name) == 0) {
|
||||
|
|
|
@ -63,11 +63,11 @@ class Executor {
|
|||
// asynchronously.
|
||||
bool HandleModelUpdateAsync(const std::map<std::string, UploadData> &feature_map);
|
||||
|
||||
// Forcibly overwrite specific weights in overwriteWeights message.
|
||||
bool HandleOverwriteWeightsByKey(const std::map<std::string, Address> &feature_map);
|
||||
// Overwrite the weights in server using pushed feature map.
|
||||
bool HandlePushWeight(const std::map<std::string, Address> &feature_map);
|
||||
|
||||
// Returns value for multiple trainable parameters passed by weight_names.
|
||||
std::map<std::string, AddressPtr> HandleGetWeightsByKey(const std::vector<std::string> ¶m_names);
|
||||
// Returns multiple trainable parameters passed by weight_names.
|
||||
std::map<std::string, AddressPtr> HandlePullWeight(const std::vector<std::string> ¶m_names);
|
||||
|
||||
// Reset the aggregation status for all aggregation kernels in the server.
|
||||
void ResetAggregationStatus();
|
||||
|
|
|
@ -95,6 +95,7 @@ class FedAvgKernel : public AggregationKernel {
|
|||
return;
|
||||
};
|
||||
DistributedCountService::GetInstance().RegisterCounter(name_, done_count_, {first_cnt_handler, last_cnt_handler});
|
||||
GenerateReuseKernelNodeInfo();
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -124,7 +125,6 @@ class FedAvgKernel : public AggregationKernel {
|
|||
participated_ = true;
|
||||
DistributedCountService::GetInstance().Count(
|
||||
name_, std::to_string(DistributedCountService::GetInstance().local_rank()) + "_" + std::to_string(accum_count_));
|
||||
GenerateReuseKernelNodeInfo();
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,132 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ps/server/kernel/round/pull_weight_kernel.h"
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "ps/server/model_store.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace server {
|
||||
namespace kernel {
|
||||
uint64_t PullWeightKernel::retry_count_ = 0;
|
||||
void PullWeightKernel::InitKernel(size_t) {
|
||||
executor_ = &Executor::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(executor_);
|
||||
if (!executor_->initialized()) {
|
||||
MS_LOG(EXCEPTION) << "Executor must be initialized in server pipeline.";
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
bool PullWeightKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
MS_LOG(DEBUG) << "Launching PullWeightKernel kernel.";
|
||||
void *req_data = inputs[0]->addr;
|
||||
std::shared_ptr<FBBuilder> fbb = std::make_shared<FBBuilder>();
|
||||
if (fbb == nullptr || req_data == nullptr) {
|
||||
MS_LOG(ERROR) << "FBBuilder builder or req_data is nullptr.";
|
||||
return false;
|
||||
}
|
||||
|
||||
const schema::RequestPullWeight *pull_weight_req = flatbuffers::GetRoot<schema::RequestPullWeight>(req_data);
|
||||
if (pull_weight_req == nullptr) {
|
||||
std::string reason = "Building flatbuffers schema failed for RequestPullWeight";
|
||||
BuildPullWeightRsp(fbb, schema::ResponseCode_RequestError, reason, {});
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return false;
|
||||
}
|
||||
|
||||
PullWeight(fbb, pull_weight_req);
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
|
||||
bool PullWeightKernel::Reset() { return true; }
|
||||
|
||||
void PullWeightKernel::PullWeight(std::shared_ptr<FBBuilder> fbb, const schema::RequestPullWeight *pull_weight_req) {
|
||||
std::map<std::string, AddressPtr> feature_maps = {};
|
||||
size_t current_iter = LocalMetaStore::GetInstance().curr_iter_num();
|
||||
size_t pull_weight_iter = static_cast<size_t>(pull_weight_req->iteration());
|
||||
// The PullWeight round should be in the same iteration as other rounds.
|
||||
if (pull_weight_iter != current_iter) {
|
||||
std::string reason = "PullWeight iteration " + std::to_string(pull_weight_iter) +
|
||||
" is invalid. Server current iteration: " + std::to_string(current_iter);
|
||||
BuildPullWeightRsp(fbb, schema::ResponseCode_RequestError, reason, feature_maps);
|
||||
MS_LOG(WARNING) << reason;
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<std::string> weight_names = {};
|
||||
auto weights_names_fbs = pull_weight_req->weight_names();
|
||||
for (size_t i = 0; i < weights_names_fbs->size(); i++) {
|
||||
weight_names.push_back(weights_names_fbs->Get(i)->str());
|
||||
}
|
||||
if (!executor_->IsWeightAggrDone(weight_names)) {
|
||||
retry_count_++;
|
||||
std::string reason = "The aggregation for the weights is not done yet.";
|
||||
BuildPullWeightRsp(fbb, schema::ResponseCode_SucNotReady, reason, feature_maps);
|
||||
if (retry_count_ % 10 == 0) {
|
||||
MS_LOG(WARNING) << reason << " Retry count is " << retry_count_;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
feature_maps = executor_->HandlePullWeight(weight_names);
|
||||
if (feature_maps.empty()) {
|
||||
std::string reason = "The feature_map is empty for the given weight names.";
|
||||
BuildPullWeightRsp(fbb, schema::ResponseCode_RequestError, reason, feature_maps);
|
||||
MS_LOG(WARNING) << reason;
|
||||
return;
|
||||
}
|
||||
|
||||
BuildPullWeightRsp(fbb, schema::ResponseCode_SUCCEED,
|
||||
"Pull weights by weight names for iteration " + std::to_string(pull_weight_iter) + " success.",
|
||||
feature_maps);
|
||||
return;
|
||||
}
|
||||
|
||||
void PullWeightKernel::BuildPullWeightRsp(std::shared_ptr<FBBuilder> fbb, const schema::ResponseCode retcode,
|
||||
const std::string &reason,
|
||||
const std::map<std::string, AddressPtr> &feature_maps) {
|
||||
auto fbs_reason = fbb->CreateString(reason);
|
||||
std::vector<flatbuffers::Offset<schema::FeatureMap>> fbs_feature_maps;
|
||||
for (auto feature_map : feature_maps) {
|
||||
auto fbs_weight_fullname = fbb->CreateString(feature_map.first);
|
||||
auto fbs_weight_data =
|
||||
fbb->CreateVector(reinterpret_cast<float *>(feature_map.second->addr), feature_map.second->size / sizeof(float));
|
||||
auto fbs_feature_map = schema::CreateFeatureMap(*(fbb.get()), fbs_weight_fullname, fbs_weight_data);
|
||||
fbs_feature_maps.push_back(fbs_feature_map);
|
||||
}
|
||||
auto fbs_feature_maps_vector = fbb->CreateVector(fbs_feature_maps);
|
||||
|
||||
schema::ResponsePullWeightBuilder rsp_pull_weight_builder(*(fbb.get()));
|
||||
rsp_pull_weight_builder.add_retcode(retcode);
|
||||
rsp_pull_weight_builder.add_reason(fbs_reason);
|
||||
rsp_pull_weight_builder.add_feature_map(fbs_feature_maps_vector);
|
||||
auto rsp_pull_weight = rsp_pull_weight_builder.Finish();
|
||||
fbb->Finish(rsp_pull_weight);
|
||||
return;
|
||||
}
|
||||
|
||||
REG_ROUND_KERNEL(pullWeight, PullWeightKernel)
|
||||
} // namespace kernel
|
||||
} // namespace server
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,57 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_PULL_WEIGHT_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_PULL_WEIGHT_KERNEL_H_
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "ps/server/common.h"
|
||||
#include "ps/server/kernel/round/round_kernel.h"
|
||||
#include "ps/server/kernel/round/round_kernel_factory.h"
|
||||
#include "ps/server/executor.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace server {
|
||||
namespace kernel {
|
||||
class PullWeightKernel : public RoundKernel {
|
||||
public:
|
||||
PullWeightKernel() = default;
|
||||
~PullWeightKernel() override = default;
|
||||
|
||||
void InitKernel(size_t required_cnt) override;
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs);
|
||||
bool Reset() override;
|
||||
|
||||
private:
|
||||
void PullWeight(std::shared_ptr<FBBuilder> fbb, const schema::RequestPullWeight *pull_weight_req);
|
||||
void BuildPullWeightRsp(std::shared_ptr<FBBuilder> fbb, const schema::ResponseCode retcode, const std::string &reason,
|
||||
const std::map<std::string, AddressPtr> &feature_maps);
|
||||
|
||||
Executor *executor_;
|
||||
|
||||
// The count of retrying because the aggregation of the weights is not done.
|
||||
static uint64_t retry_count_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace server
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_PULL_WEIGHT_KERNEL_H_
|
|
@ -0,0 +1,131 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ps/server/kernel/round/push_weight_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace server {
|
||||
namespace kernel {
|
||||
void PushWeightKernel::InitKernel(size_t) {
|
||||
executor_ = &Executor::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(executor_);
|
||||
if (!executor_->initialized()) {
|
||||
MS_LOG(EXCEPTION) << "Executor must be initialized in server pipeline.";
|
||||
return;
|
||||
}
|
||||
local_rank_ = DistributedCountService::GetInstance().local_rank();
|
||||
}
|
||||
|
||||
bool PushWeightKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
MS_LOG(INFO) << "Launching PushWeightKernel kernel.";
|
||||
void *req_data = inputs[0]->addr;
|
||||
std::shared_ptr<FBBuilder> fbb = std::make_shared<FBBuilder>();
|
||||
if (fbb == nullptr || req_data == nullptr) {
|
||||
MS_LOG(ERROR) << "FBBuilder builder or req_data is nullptr.";
|
||||
return false;
|
||||
}
|
||||
|
||||
const schema::RequestPushWeight *push_weight_req = flatbuffers::GetRoot<schema::RequestPushWeight>(req_data);
|
||||
if (push_weight_req == nullptr) {
|
||||
std::string reason = "Building flatbuffers schema failed for RequestPushWeight";
|
||||
BuildPushWeightRsp(fbb, schema::ResponseCode_RequestError, reason);
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return false;
|
||||
}
|
||||
|
||||
PushWeight(fbb, push_weight_req);
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
|
||||
bool PushWeightKernel::Reset() {
|
||||
MS_LOG(INFO) << "PushWeightKernel reset!";
|
||||
StopTimer();
|
||||
DistributedCountService::GetInstance().ResetCounter(name_);
|
||||
return true;
|
||||
}
|
||||
|
||||
void PushWeightKernel::OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &) {
|
||||
if (PSContext::instance()->resetter_round() == ResetterRound::kPushWeight) {
|
||||
FinishIteration();
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
void PushWeightKernel::PushWeight(std::shared_ptr<FBBuilder> fbb, const schema::RequestPushWeight *push_weight_req) {
|
||||
if (fbb == nullptr || push_weight_req == nullptr) {
|
||||
return;
|
||||
}
|
||||
size_t iteration = static_cast<size_t>(push_weight_req->iteration());
|
||||
if (iteration != LocalMetaStore::GetInstance().curr_iter_num()) {
|
||||
std::string reason = "PushWeight iteration number is invalid:" + std::to_string(iteration) +
|
||||
", current iteration:" + std::to_string(LocalMetaStore::GetInstance().curr_iter_num());
|
||||
BuildPushWeightRsp(fbb, schema::ResponseCode_OutOfTime, reason);
|
||||
MS_LOG(ERROR) << reason;
|
||||
return;
|
||||
}
|
||||
|
||||
std::map<std::string, Address> upload_feature_map = ParseFeatureMap(push_weight_req);
|
||||
if (upload_feature_map.empty()) {
|
||||
std::string reason = "PushWeight overwrite feature_map is empty.";
|
||||
BuildPushWeightRsp(fbb, schema::ResponseCode_RequestError, reason);
|
||||
MS_LOG(ERROR) << reason;
|
||||
return;
|
||||
}
|
||||
|
||||
if (!executor_->HandlePushWeight(upload_feature_map)) {
|
||||
std::string reason = "OverwriteWeights failed.";
|
||||
BuildPushWeightRsp(fbb, schema::ResponseCode_SystemError, reason);
|
||||
MS_LOG(ERROR) << reason;
|
||||
return;
|
||||
}
|
||||
|
||||
DistributedCountService::GetInstance().Count(name_, std::to_string(local_rank_));
|
||||
BuildPushWeightRsp(fbb, schema::ResponseCode_SUCCEED, "PushWeight succeed.");
|
||||
return;
|
||||
}
|
||||
|
||||
std::map<std::string, Address> PushWeightKernel::ParseFeatureMap(const schema::RequestPushWeight *push_weight_req) {
|
||||
RETURN_IF_NULL(push_weight_req, {});
|
||||
std::map<std::string, Address> upload_feature_map;
|
||||
auto fbs_feature_map = push_weight_req->feature_map();
|
||||
for (size_t i = 0; i < fbs_feature_map->size(); i++) {
|
||||
std::string weight_full_name = fbs_feature_map->Get(i)->weight_fullname()->str();
|
||||
float *weight_data = const_cast<float *>(fbs_feature_map->Get(i)->data()->data());
|
||||
size_t weight_size = fbs_feature_map->Get(i)->data()->size() * sizeof(float);
|
||||
upload_feature_map[weight_full_name] = {weight_data, weight_size};
|
||||
}
|
||||
return upload_feature_map;
|
||||
}
|
||||
|
||||
void PushWeightKernel::BuildPushWeightRsp(std::shared_ptr<FBBuilder> fbb, const schema::ResponseCode retcode,
|
||||
const std::string &reason) {
|
||||
auto fbs_reason = fbb->CreateString(reason);
|
||||
schema::ResponsePushWeightBuilder rsp_push_weight_builder(*(fbb.get()));
|
||||
rsp_push_weight_builder.add_retcode(retcode);
|
||||
rsp_push_weight_builder.add_reason(fbs_reason);
|
||||
auto rsp_push_weight = rsp_push_weight_builder.Finish();
|
||||
fbb->Finish(rsp_push_weight);
|
||||
return;
|
||||
}
|
||||
|
||||
REG_ROUND_KERNEL(pushWeight, PushWeightKernel)
|
||||
} // namespace kernel
|
||||
} // namespace server
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,57 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_PUSH_WEIGHT_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_PUSH_WEIGHT_KERNEL_H_
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "ps/server/common.h"
|
||||
#include "ps/server/kernel/round/round_kernel.h"
|
||||
#include "ps/server/kernel/round/round_kernel_factory.h"
|
||||
#include "ps/server/executor.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace server {
|
||||
namespace kernel {
|
||||
class PushWeightKernel : public RoundKernel {
|
||||
public:
|
||||
PushWeightKernel() = default;
|
||||
~PushWeightKernel() override = default;
|
||||
|
||||
void InitKernel(size_t threshold_count) override;
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs);
|
||||
bool Reset() override;
|
||||
void OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message) override;
|
||||
|
||||
private:
|
||||
void PushWeight(std::shared_ptr<FBBuilder> fbb, const schema::RequestPushWeight *push_weight_req);
|
||||
std::map<std::string, Address> ParseFeatureMap(const schema::RequestPushWeight *push_weight_req);
|
||||
void BuildPushWeightRsp(std::shared_ptr<FBBuilder> fbb, const schema::ResponseCode retcode,
|
||||
const std::string &reason);
|
||||
|
||||
Executor *executor_;
|
||||
uint32_t local_rank_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace server
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_PUSH_WEIGHT_KERNEL_H_
|
|
@ -36,8 +36,14 @@ bool ParameterAggregator::Init(const CNodePtr &cnode, size_t threshold_count) {
|
|||
required_pull_count_ = threshold_count;
|
||||
|
||||
MS_LOG(DEBUG) << "Start initializing kernels for " << AnfAlgo::GetCNodeName(cnode);
|
||||
InitAggregationKernels(cnode);
|
||||
InitOptimizerKernels(cnode);
|
||||
if (!InitAggregationKernels(cnode)) {
|
||||
MS_LOG(EXCEPTION) << "Initializing aggregation kernels failed.";
|
||||
return false;
|
||||
}
|
||||
if (!InitOptimizerKernels(cnode)) {
|
||||
MS_LOG(EXCEPTION) << "Initializing optimizer kernels failed.";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -183,9 +189,10 @@ bool ParameterAggregator::InitAggregationKernels(const CNodePtr &cnode) {
|
|||
}
|
||||
|
||||
bool ParameterAggregator::InitOptimizerKernels(const CNodePtr &cnode) {
|
||||
if (PSContext::instance()->server_mode() == kServerModeFL) {
|
||||
if (PSContext::instance()->server_mode() == kServerModeFL ||
|
||||
PSContext::instance()->server_mode() == kServerModeHybrid) {
|
||||
MS_LOG(DEBUG) << "Federated learning mode doesn't need optimizer kernel.";
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
const std::string &name = AnfAlgo::GetCNodeName(cnode);
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#include "ps/util.h"
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "ps/constants.h"
|
||||
#include "ps/ps_context.h"
|
||||
#include "utils/ms_utils.h"
|
||||
|
@ -128,5 +129,107 @@ void Util::ReduceSparseGradient(float *gradients, int *indices, const size_t ind
|
|||
|
||||
mindspore::kernel::SparseOptimizerCPUKernel::BucketReduceSparseGradient(param);
|
||||
}
|
||||
|
||||
bool Util::FuseServerCommOps(const pipeline::ResourcePtr &res) {
|
||||
FuncGraphPtr func_graph = res->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
DoFusion(func_graph, kPullWeightOpName, kFusedPullWeightOpName);
|
||||
DoFusion(func_graph, kPushWeightOpName, kFusedPushWeightOpName);
|
||||
return true;
|
||||
}
|
||||
|
||||
void Util::DoFusion(FuncGraphPtr func_graph, const std::string &cnode_name, const std::string &fused_cnode_name) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return());
|
||||
|
||||
std::vector<AnfNodePtr> single_nodes;
|
||||
std::vector<std::string> weight_names;
|
||||
std::vector<int64_t> indices;
|
||||
for (const AnfNodePtr &node : node_list) {
|
||||
if (node != nullptr && node->isa<CNode>()) {
|
||||
if (AnfAlgo::GetCNodeName(node) == cnode_name) {
|
||||
single_nodes.push_back(node);
|
||||
|
||||
auto weight_name_value_node =
|
||||
AnfAlgo::GetInputNode(node->cast<CNodePtr>(), kNodeInputWeightNameOffset)->cast<ValueNodePtr>();
|
||||
const std::string &weight_name = GetValue<std::string>(weight_name_value_node->value());
|
||||
weight_names.push_back(weight_name);
|
||||
|
||||
auto weight_index_value_node =
|
||||
AnfAlgo::GetInputNode(node->cast<CNodePtr>(), kNodeInputWeightIndexOffset)->cast<ValueNodePtr>();
|
||||
int64_t weight_index = GetValue<int64_t>(weight_index_value_node->value());
|
||||
indices.push_back(weight_index);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auto prim = std::make_shared<Primitive>(fused_cnode_name);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
std::vector<AnfNodePtr> fused_node_inputs = {};
|
||||
fused_node_inputs.push_back(NewValueNode(prim));
|
||||
std::for_each(single_nodes.begin(), single_nodes.end(), [&](AnfNodePtr node) {
|
||||
fused_node_inputs.push_back(AnfAlgo::GetInputNode(node->cast<CNodePtr>(), 0));
|
||||
});
|
||||
|
||||
auto fused_cnode = func_graph->NewCNode(fused_node_inputs);
|
||||
MS_EXCEPTION_IF_NULL(fused_cnode);
|
||||
AnfAlgo::SetNodeAttr(kAttrPsKey, MakeValue(weight_names), fused_cnode);
|
||||
AnfAlgo::SetNodeAttr(kAttrIndex, MakeValue(indices), fused_cnode);
|
||||
AnfAlgo::SetNodeAttr(kAttrPrimitiveTarget, MakeValue(kCPUDevice), fused_cnode);
|
||||
|
||||
auto kernel_info = std::make_shared<device::KernelInfo>();
|
||||
MS_EXCEPTION_IF_NULL(kernel_info);
|
||||
fused_cnode->set_kernel_info(kernel_info);
|
||||
auto kernel_build_info = GenerateKernelBuildInfo(single_nodes);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, fused_cnode.get());
|
||||
|
||||
AbstractBasePtrList abstract_list;
|
||||
for (const auto &node : single_nodes) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
abstract_list.push_back(cnode->abstract());
|
||||
}
|
||||
auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
|
||||
MS_EXCEPTION_IF_NULL(abstract_tuple);
|
||||
fused_cnode->set_abstract(abstract_tuple);
|
||||
|
||||
auto manager = func_graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
for (const auto &node : single_nodes) {
|
||||
if (!manager->Replace(node, fused_cnode)) {
|
||||
MS_LOG(EXCEPTION) << "manager replace node failed";
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
kernel::KernelBuildInfoPtr Util::GenerateKernelBuildInfo(const std::vector<AnfNodePtr> &node_list) {
|
||||
std::vector<std::string> inputs_device_format;
|
||||
std::vector<std::string> outputs_device_format;
|
||||
std::vector<TypeId> inputs_device_type;
|
||||
std::vector<TypeId> outputs_device_type;
|
||||
std::vector<std::vector<size_t>> outputs_shape;
|
||||
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
|
||||
for (size_t idx = 0; idx < node_list.size(); ++idx) {
|
||||
auto cnode = utils::cast<CNodePtr>(node_list[idx]);
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(cnode);
|
||||
for (size_t input_index = 0; input_index < input_num; ++input_index) {
|
||||
inputs_device_format.push_back(kOpFormat_DEFAULT);
|
||||
inputs_device_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index));
|
||||
}
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
|
||||
for (size_t output_index = 0; output_index < output_num; ++output_index) {
|
||||
outputs_device_format.push_back(kOpFormat_DEFAULT);
|
||||
outputs_device_type.push_back(AnfAlgo::GetOutputInferDataType(cnode, output_index));
|
||||
outputs_shape.push_back(AnfAlgo::GetOutputInferShape(cnode, output_index));
|
||||
}
|
||||
}
|
||||
builder.SetInputsFormat(inputs_device_format);
|
||||
builder.SetOutputsFormat(outputs_device_format);
|
||||
builder.SetInputsDeviceType(inputs_device_type);
|
||||
builder.SetOutputsDeviceType(outputs_device_type);
|
||||
return builder.Build();
|
||||
}
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -18,8 +18,10 @@
|
|||
#define MINDSPORE_CCSRC_PS_UTIL_H_
|
||||
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include "frontend/optimizer/optimizer.h"
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "backend/kernel_compiler/common_utils.h"
|
||||
#include "backend/kernel_compiler/cpu/sparse_optimizer_cpu_kernel.h"
|
||||
|
@ -35,6 +37,9 @@ struct ParamInitInfo {
|
|||
float init_val_{0};
|
||||
};
|
||||
|
||||
constexpr size_t kNodeInputWeightNameOffset = 1;
|
||||
constexpr size_t kNodeInputWeightIndexOffset = 2;
|
||||
|
||||
class Util {
|
||||
public:
|
||||
static bool IsRoleOfPServer();
|
||||
|
@ -48,8 +53,12 @@ class Util {
|
|||
static void ReduceSparseGradient(float *gradients, int *indices, const size_t indices_size, size_t segment_size,
|
||||
const size_t first_dim_size, const size_t outer_dim_size,
|
||||
mindspore::kernel::SparseGradient<int> *unique_sparse_grad);
|
||||
static bool FuseServerCommOps(const pipeline::ResourcePtr &res);
|
||||
|
||||
private:
|
||||
static void DoFusion(FuncGraphPtr func_graph, const std::string &cnode_name, const std::string &fused_cnode_name);
|
||||
static kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const std::vector<AnfNodePtr> &node_list);
|
||||
|
||||
static std::unordered_map<std::string, int64_t> optimizer_to_ids;
|
||||
static std::unordered_map<int64_t, std::string> id_to_optimizers;
|
||||
static std::unordered_map<int64_t, std::string> id_to_optimizer_nodes;
|
||||
|
|
|
@ -57,7 +57,8 @@ _set_ps_context_func_map = {
|
|||
}
|
||||
|
||||
_get_ps_context_func_map = {
|
||||
"enable_ps": ps_context().is_ps_mode
|
||||
"enable_ps": ps_context().is_ps_mode,
|
||||
"ms_role": ps_context().ms_role
|
||||
}
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue