Add iteration metrics

This commit is contained in:
ZPaC 2021-08-12 20:29:06 +08:00
parent b59e9c3e20
commit c26e07aadf
15 changed files with 622 additions and 6 deletions

View File

@ -84,6 +84,7 @@ if(NOT ENABLE_CPU OR WIN32)
list(REMOVE_ITEM CPU_SRC_LIST "cpu/fl/get_model_kernel.cc")
list(REMOVE_ITEM CPU_SRC_LIST "cpu/fl/start_fl_job_kernel.cc")
list(REMOVE_ITEM CPU_SRC_LIST "cpu/fl/update_model_kernel.cc")
list(REMOVE_ITEM CPU_SRC_LIST "cpu/fl/push_metrics_kernel.cc")
endif()
if(ENABLE_GPU)

View File

@ -0,0 +1,26 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/cpu/fl/push_metrics_kernel.h"
namespace mindspore {
namespace kernel {
// Registers PushMetricsKernel instantiated with float for the PushMetrics operator. The kernel attribute declares
// two float32 inputs and one float32 output.
MS_REG_CPU_KERNEL_T(
PushMetrics,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
PushMetricsKernel, float);
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,125 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_FL_PUSH_METRICS_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_FL_PUSH_METRICS_H_
#include <chrono>
#include <functional>
#include <memory>
#include <string>
#include <thread>
#include <vector>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
#include "fl/worker/fl_worker.h"
namespace mindspore {
namespace kernel {
// The time in milliseconds to sleep after a failed PushMetrics request before the next attempt.
constexpr int kRetryDurationOfPushMetrics = 500;
// The retry-count limit that PushMetricsKernel::Launch compares its attempt counter against.
// 3600 attempts at 500 ms intervals is about 30 minutes.
constexpr int kMaxRetryTime = 3600;
template <typename T>
class PushMetricsKernel : public CPUKernel {
public:
PushMetricsKernel() : fbb_(nullptr), total_iteration_(0) {}
~PushMetricsKernel() override = default;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, const std::vector<AddressPtr> &) {
if (inputs.size() != 2) {
MS_LOG(EXCEPTION) << "Input number of PushMetricsKernel should be " << 2 << ", but got " << inputs.size();
return false;
}
MS_EXCEPTION_IF_NULL(inputs[0]->addr);
MS_EXCEPTION_IF_NULL(inputs[1]->addr);
T loss = *(static_cast<float *>(inputs[0]->addr));
T accuracy = *(static_cast<float *>(inputs[1]->addr));
if (!BuildPushMetricsReq(fbb_, loss, accuracy)) {
MS_LOG(EXCEPTION) << "Building request for FusedPushWeight failed.";
return false;
}
uint32_t retry_time = 0;
std::shared_ptr<std::vector<unsigned char>> push_metrics_rsp_msg = nullptr;
do {
retry_time++;
if (!fl::worker::FLWorker::GetInstance().SendToServer(fl::kLeaderServerRank, fbb_->GetBufferPointer(),
fbb_->GetSize(), ps::core::TcpUserCommand::kPushMetrics,
&push_metrics_rsp_msg)) {
MS_LOG(WARNING) << "Sending request for PushMetrics to server 0 failed.";
std::this_thread::sleep_for(std::chrono::milliseconds(kRetryDurationOfPushMetrics));
continue;
} else {
break;
}
} while (retry_time > kMaxRetryTime);
flatbuffers::Verifier verifier(push_metrics_rsp_msg->data(), push_metrics_rsp_msg->size());
if (!verifier.VerifyBuffer<schema::ResponsePushMetrics>()) {
MS_LOG(EXCEPTION) << "The schema of ResponsePushMetrics is invalid.";
return false;
}
const schema::ResponsePushMetrics *push_metrics_rsp =
flatbuffers::GetRoot<schema::ResponsePushMetrics>(push_metrics_rsp_msg->data());
MS_EXCEPTION_IF_NULL(push_metrics_rsp);
auto response_code = push_metrics_rsp->retcode();
switch (response_code) {
case schema::ResponseCode_SUCCEED:
case schema::ResponseCode_OutOfTime:
break;
default:
MS_LOG(EXCEPTION) << "Launching push metrics for worker failed.";
}
MS_LOG(INFO) << "Push metrics for loss and accuracy success.";
fl::worker::FLWorker::GetInstance().SetIterationCompleted();
return true;
}
void Init(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
fbb_ = std::make_shared<fl::FBBuilder>();
MS_EXCEPTION_IF_NULL(fbb_);
input_size_list_.push_back(sizeof(float));
input_size_list_.push_back(sizeof(float));
output_size_list_.push_back(sizeof(float));
}
void InitKernel(const CNodePtr &kernel_node) { return; }
protected:
void InitSizeLists() { return; }
private:
bool BuildPushMetricsReq(const std::shared_ptr<fl::FBBuilder> &fbb, T loss, T accuracy) {
MS_EXCEPTION_IF_NULL(fbb);
schema::RequestPushMetricsBuilder req_push_metrics_builder(*(fbb.get()));
req_push_metrics_builder.add_loss(loss);
req_push_metrics_builder.add_accuracy(accuracy);
auto req_push_metrics = req_push_metrics_builder.Finish();
fbb->Finish(req_push_metrics);
return true;
}
std::shared_ptr<fl::FBBuilder> fbb_;
size_t total_iteration_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_FL_PUSH_METRICS_H_

View File

@ -20,6 +20,7 @@ if(NOT ENABLE_CPU OR WIN32)
list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/round/get_secrets_kernel.cc")
list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/round/reconstruct_secrets_kernel.cc")
list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/round/share_secrets_kernel.cc")
list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/round/push_metrics_kernel.cc")
list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/params_info.cc")
list(REMOVE_ITEM _FL_SRC_FILES "server/consistent_hash_ring.cc")
list(REMOVE_ITEM _FL_SRC_FILES "server/iteration_timer.cc")
@ -34,6 +35,7 @@ if(NOT ENABLE_CPU OR WIN32)
list(REMOVE_ITEM _FL_SRC_FILES "server/model_store.cc")
list(REMOVE_ITEM _FL_SRC_FILES "server/round.cc")
list(REMOVE_ITEM _FL_SRC_FILES "server/server.cc")
list(REMOVE_ITEM _FL_SRC_FILES "server/iteration_metrics.cc")
list(REMOVE_ITEM _FL_SRC_FILES "worker/fl_worker.cc")
list(REMOVE_ITEM _FL_SRC_FILES "armour/secure_protocol/encrypt.cc")
list(REMOVE_ITEM _FL_SRC_FILES "armour/secure_protocol/key_agreement.cc")

View File

@ -65,6 +65,20 @@ struct CipherConfig {
size_t reconstruct_secrets_threshold = 0;
};
// Every instance is one training loop that runs fl_iteration_num iterations of federated learning.
// During every instance, server's training process could be controlled by scheduler, which will change the state of
// this instance.
enum class InstanceState {
// If this instance is in kRunning state, server could communicate with client/worker and the training process moves
// on.
kRunning = 0,
// The server is not available for client/worker if in kDisable state.
kDisable,
// The server is not available for client/worker if in kFinish state. This state also means one instance has
// finished. In other words, fl_iteration_num iterations are completed.
kFinish
};
using mindspore::kernel::Address;
using mindspore::kernel::AddressPtr;
using mindspore::kernel::CPUKernel;

View File

@ -160,6 +160,10 @@ const std::vector<std::shared_ptr<Round>> &Iteration::rounds() const { return ro
// Returns the stored validity flag of the last iteration.
bool Iteration::is_last_iteration_valid() const {
  return is_last_iteration_valid_;
}

// Stores the loss value handed in for this iteration.
void Iteration::set_loss(float loss) {
  loss_ = loss;
}

// Stores the accuracy value handed in for this iteration.
void Iteration::set_accuracy(float accuracy) {
  accuracy_ = accuracy;
}
bool Iteration::SyncIteration(uint32_t rank) {
MS_ERROR_IF_NULL_W_RET_VAL(communicator_, false);
SyncIterationRequest sync_iter_req;

View File

@ -83,6 +83,10 @@ class Iteration {
bool is_last_iteration_valid() const;
// Set the instance metrics which will be called for each iteration.
void set_loss(float loss);
void set_accuracy(float accuracy);
private:
Iteration()
: server_node_(nullptr),
@ -91,7 +95,9 @@ class Iteration {
iteration_loop_count_(0),
iteration_num_(1),
is_last_iteration_valid_(true),
pinned_iter_num_(0) {
pinned_iter_num_(0),
loss_(0.0),
accuracy_(0.0) {
LocalMetaStore::GetInstance().set_curr_iter_num(iteration_num_);
}
~Iteration() = default;
@ -153,6 +159,12 @@ class Iteration {
// To avoid Next method is called multiple times in one iteration, we should mark the iteration number.
uint64_t pinned_iter_num_;
std::mutex pinned_mtx_;
// The training loss after this federated learning iteration, passed by worker.
float loss_;
// The evaluation result after this federated learning iteration, passed by worker.
float accuracy_;
};
} // namespace server
} // namespace fl

View File

@ -0,0 +1,117 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <string>
#include <fstream>
#include "debug/common.h"
#include "ps/constants.h"
#include "fl/server/iteration_metrics.h"
namespace mindspore {
namespace fl {
namespace server {
bool IterationMetrics::Initialize() {
config_ = std::make_unique<ps::core::FileConfiguration>(config_file_path_);
MS_EXCEPTION_IF_NULL(config_);
if (!config_->Initialize()) {
MS_LOG(EXCEPTION) << "Initializing for metrics failed. Config file path " << config_file_path_
<< " may be invalid or not exist.";
return false;
}
// Read the metrics file path. If file is not set or not exits, create one.
if (!config_->Exists(kMetrics)) {
MS_LOG(WARNING) << "Metrics config is not set. Don't write metrics.";
return false;
} else {
std::string value = config_->Get(kMetrics, "");
nlohmann::json value_json;
try {
value_json = nlohmann::json::parse(value);
} catch (const std::exception &e) {
MS_LOG(EXCEPTION) << "The hyper-parameter data is not in json format.";
return false;
}
// Parse the storage type.
uint32_t storage_type = JsonGetKeyWithException<uint32_t>(value_json, ps::kStoreType);
if (std::to_string(storage_type) != ps::kFileStorage) {
MS_LOG(EXCEPTION) << "Storage type " << storage_type << " is not supported.";
return false;
}
// Parse storage file path.
std::string metrics_file_path = JsonGetKeyWithException<std::string>(value_json, ps::kStoreFilePath);
auto realpath = Common::GetRealPath(metrics_file_path);
if (!realpath.has_value()) {
MS_LOG(EXCEPTION) << "Get real path for " << metrics_file_path << " failed.";
return false;
}
metrics_file_.open(metrics_file_path, std::ios::ate | std::ios::out);
}
initialized_ = true;
return true;
}
bool IterationMetrics::Summarize() {
if (!metrics_file_.is_open()) {
metrics_file_.clear();
MS_LOG(ERROR) << "The metrics file is not opened.";
return false;
}
js_[kFLName] = fl_name_;
js_[kInstanceStatus] = kInstanceStateName.at(instance_state_);
js_[kFLIterationNum] = fl_iteration_num_;
js_[kCurIteration] = cur_iteration_num_;
js_[kJoinedClientNum] = joined_client_num_;
js_[kRejectedClientNum] = rejected_client_num_;
js_[kMetricsAuc] = accuracy_;
js_[kMetricsLoss] = loss_;
js_[kIterExecutionTime] = iteration_time_cost_;
metrics_file_ << js_ << "\n";
metrics_file_.flush();
return true;
}
// Currently a no-op that always reports success.
bool IterationMetrics::Clear() {
  return true;
}

// The setters below only store the given value; it is written out by the next Summarize call.

void IterationMetrics::set_fl_name(const std::string &fl_name) {
  fl_name_ = fl_name;
}

void IterationMetrics::set_fl_iteration_num(size_t fl_iteration_num) {
  fl_iteration_num_ = fl_iteration_num;
}

void IterationMetrics::set_cur_iteration_num(size_t cur_iteration_num) {
  cur_iteration_num_ = cur_iteration_num;
}

void IterationMetrics::set_instance_state(InstanceState state) {
  instance_state_ = state;
}

void IterationMetrics::set_loss(float loss) {
  loss_ = loss;
}

void IterationMetrics::set_accuracy(float acc) {
  accuracy_ = acc;
}

void IterationMetrics::set_joined_client_num(size_t joined_client_num) {
  joined_client_num_ = joined_client_num;
}

void IterationMetrics::set_rejected_client_num(size_t rejected_client_num) {
  rejected_client_num_ = rejected_client_num;
}

void IterationMetrics::set_iteration_time_cost(uint64_t iteration_time_cost) {
  iteration_time_cost_ = iteration_time_cost;
}
} // namespace server
} // namespace fl
} // namespace mindspore

View File

@ -0,0 +1,135 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FL_SERVER_ITERATION_METRICS_H_
#define MINDSPORE_CCSRC_FL_SERVER_ITERATION_METRICS_H_
#include <map>
#include <string>
#include <memory>
#include <fstream>
#include "ps/ps_context.h"
#include "ps/core/configuration.h"
#include "ps/core/file_configuration.h"
#include "fl/server/local_meta_store.h"
#include "fl/server/iteration.h"
namespace mindspore {
namespace fl {
namespace server {
// Keys of the fields in the JSON record that IterationMetrics::Summarize writes to the metrics file.
constexpr auto kFLName = "flName";
constexpr auto kInstanceStatus = "instanceStatus";
constexpr auto kFLIterationNum = "flIterationNum";
constexpr auto kCurIteration = "currentIteration";
constexpr auto kJoinedClientNum = "joinedClientNum";
constexpr auto kRejectedClientNum = "rejectedClientNum";
// Summarize stores the accuracy value under this key.
constexpr auto kMetricsAuc = "metricsAuc";
constexpr auto kMetricsLoss = "metricsLoss";
constexpr auto kIterExecutionTime = "iterationExecutionTime";
// The text written under kInstanceStatus for each InstanceState value.
const std::map<InstanceState, std::string> kInstanceStateName = {
{InstanceState::kRunning, "running"}, {InstanceState::kDisable, "disable"}, {InstanceState::kFinish, "finish"}};
// Returns the value stored under key in json, converted to T.
// Raises through MS_LOG(EXCEPTION) when json does not contain key. A value that cannot be converted to T is
// reported by the json library's own exception.
template <typename T>
inline T JsonGetKeyWithException(const nlohmann::json &json, const std::string &key) {
  if (!json.contains(key)) {
    MS_LOG(EXCEPTION) << "The key " << key << " does not exist in json " << json.dump();
  }
  // at() is range-checked, unlike the const operator[] which only asserts that the key exists.
  return json.at(key).get<T>();
}
// Key of the metrics entry that IterationMetrics::Initialize looks up in the main config file.
constexpr auto kMetrics = "metrics";
// Holds the metrics of one federated learning iteration and writes them as JSON lines to a metrics file whose
// location is read from the main config file.
class IterationMetrics {
 public:
  // config_file is the path of the main config file that contains the metrics entry. Nothing is opened until
  // Initialize is called.
  explicit IterationMetrics(const std::string &config_file)
      : initialized_(false),  // Was left uninitialized before; only Initialize sets it to true.
        config_file_path_(config_file),
        config_(nullptr),
        fl_name_(""),
        fl_iteration_num_(0),
        cur_iteration_num_(0),
        instance_state_(InstanceState::kFinish),
        loss_(0.0),
        accuracy_(0.0),
        joined_client_num_(0),
        rejected_client_num_(0),
        iteration_time_cost_(0) {}
  ~IterationMetrics() = default;

  // Reads the metrics entry of the config file and opens the metrics file. Returns whether metrics can be written.
  bool Initialize();

  // Gather the details of this iteration and output to the persistent storage.
  bool Summarize();

  // Clear data in persistent storage.
  bool Clear();

  // Setters for the metrics data.
  void set_fl_name(const std::string &fl_name);
  void set_fl_iteration_num(size_t fl_iteration_num);
  void set_cur_iteration_num(size_t cur_iteration_num);
  void set_instance_state(InstanceState state);
  void set_loss(float loss);
  void set_accuracy(float acc);
  void set_joined_client_num(size_t joined_client_num);
  void set_rejected_client_num(size_t rejected_client_num);
  void set_iteration_time_cost(uint64_t iteration_time_cost);

 private:
  // Set to true at the end of a successful Initialize call.
  bool initialized_;

  // This is the main config file set by ps context.
  std::string config_file_path_;
  std::unique_ptr<ps::core::FileConfiguration> config_;

  // The metrics file object.
  std::fstream metrics_file_;

  // Json object of metrics data.
  nlohmann::json js_;

  // The federated learning job name. Set by ps_context.
  std::string fl_name_;

  // Federated learning iteration number. Set by ps_context.
  // If this number of iterations are completed, one instance is finished.
  size_t fl_iteration_num_;

  // Current iteration number.
  size_t cur_iteration_num_;

  // Current instance state.
  InstanceState instance_state_;

  // The training loss after this federated learning iteration, passed by worker.
  float loss_;

  // The evaluation result after this federated learning iteration, passed by worker.
  float accuracy_;

  // The number of clients which join the federated aggregation.
  size_t joined_client_num_;

  // The number of clients which are not involved in federated aggregation.
  size_t rejected_client_num_;

  // The time cost in millisecond for this completed iteration.
  uint64_t iteration_time_cost_;
};
} // namespace server
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FL_SERVER_ITERATION_METRICS_H_

View File

@ -0,0 +1,112 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <string>
#include "fl/server/kernel/round/push_metrics_kernel.h"
#include "fl/server/iteration.h"
namespace mindspore {
namespace fl {
namespace server {
namespace kernel {
// Caches the local rank reported by DistributedCountService. The threshold_count argument is not used here.
void PushMetricsKernel::InitKernel(size_t threshold_count) {
  const auto rank_of_this_server = DistributedCountService::GetInstance().local_rank();
  local_rank_ = rank_of_this_server;
}
bool PushMetricsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) {
MS_LOG(INFO) << "Launching PushMetricsKernel kernel.";
void *req_data = inputs[0]->addr;
std::shared_ptr<FBBuilder> fbb = std::make_shared<FBBuilder>();
if (fbb == nullptr || req_data == nullptr) {
std::string reason = "FBBuilder builder or req_data is nullptr.";
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, reason.c_str(), reason.size());
return true;
}
flatbuffers::Verifier verifier(reinterpret_cast<uint8_t *>(req_data), inputs[0]->size);
if (!verifier.VerifyBuffer<schema::RequestPushMetrics>()) {
std::string reason = "The schema of RequestPushMetrics is invalid.";
BuildPushMetricsRsp(fbb, schema::ResponseCode_RequestError);
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
const schema::RequestPushMetrics *push_metrics_req = flatbuffers::GetRoot<schema::RequestPushMetrics>(req_data);
if (push_metrics_req == nullptr) {
std::string reason = "Building flatbuffers schema failed for RequestPushMetrics";
BuildPushMetricsRsp(fbb, schema::ResponseCode_RequestError);
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return false;
}
ResultCode result_code = PushMetrics(fbb, push_metrics_req);
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return ConvertResultCode(result_code);
}
// Resets this round kernel: stops its timer and resets the distributed counter registered under name_.
// Always returns true.
bool PushMetricsKernel::Reset() {
MS_LOG(INFO) << "PushMetricsKernel reset!";
StopTimer();
DistributedCountService::GetInstance().ResetCounter(name_);
return true;
}
// Called for the last count event of this round. Finishes the iteration only when the configured resetter round
// is kPushMetrics; the message argument is not used.
void PushMetricsKernel::OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &) {
  const bool push_metrics_is_resetter =
    ps::PSContext::instance()->resetter_round() == ps::ResetterRound::kPushMetrics;
  if (!push_metrics_is_resetter) {
    return;
  }
  FinishIteration();
}
// Stores the loss and accuracy of the request in the Iteration instance, counts the request in
// DistributedCountService and builds the matching response in fbb.
// Returns kSuccess when counting succeeds. When counting fails it returns kFail if the reported reason is
// kNetworkError and kSuccessAndReturn otherwise. A null argument yields kSuccessAndReturn.
ResultCode PushMetricsKernel::PushMetrics(const std::shared_ptr<FBBuilder> &fbb,
                                          const schema::RequestPushMetrics *push_metrics_req) {
  MS_ERROR_IF_NULL_W_RET_VAL(fbb, ResultCode::kSuccessAndReturn);
  MS_ERROR_IF_NULL_W_RET_VAL(push_metrics_req, ResultCode::kSuccessAndReturn);

  const float reported_loss = push_metrics_req->loss();
  const float reported_accuracy = push_metrics_req->accuracy();
  Iteration::GetInstance().set_loss(reported_loss);
  Iteration::GetInstance().set_accuracy(reported_accuracy);

  std::string count_failure_reason;
  const bool counted =
    DistributedCountService::GetInstance().Count(name_, std::to_string(local_rank_), &count_failure_reason);
  if (counted) {
    BuildPushMetricsRsp(fbb, schema::ResponseCode_SUCCEED);
    return ResultCode::kSuccess;
  }

  BuildPushMetricsRsp(fbb, schema::ResponseCode_SystemError);
  MS_LOG(ERROR) << "Count for push metrics request failed.";
  if (count_failure_reason == kNetworkError) {
    return ResultCode::kFail;
  }
  return ResultCode::kSuccessAndReturn;
}
// Serializes a ResponsePushMetrics message carrying retcode into fbb. Does nothing but log when fbb is null.
void PushMetricsKernel::BuildPushMetricsRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode) {
  MS_ERROR_IF_NULL_WO_RET_VAL(fbb);
  schema::ResponsePushMetricsBuilder response_builder(*fbb);
  response_builder.add_retcode(retcode);
  fbb->Finish(response_builder.Finish());
}
REG_ROUND_KERNEL(pushMetrics, PushMetricsKernel)
} // namespace kernel
} // namespace server
} // namespace fl
} // namespace mindspore

View File

@ -0,0 +1,54 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_PUSH_METRICS_KERNEL_H_
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_PUSH_METRICS_KERNEL_H_
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "fl/server/common.h"
#include "fl/server/kernel/round/round_kernel.h"
#include "fl/server/kernel/round/round_kernel_factory.h"
#include "fl/server/executor.h"
namespace mindspore {
namespace fl {
namespace server {
namespace kernel {
// Server-side round kernel that handles PushMetrics requests: it stores the loss and accuracy carried by the
// request in the Iteration instance and answers with a ResponsePushMetrics message.
class PushMetricsKernel : public RoundKernel {
 public:
  PushMetricsKernel() = default;
  ~PushMetricsKernel() override = default;

  // Caches the local rank of this server. threshold_count is not used by this kernel.
  void InitKernel(size_t threshold_count);

  // Processes the request in inputs[0] and hands the serialized response to the outputs.
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs);

  // Stops the timer and resets the distributed counter of this round.
  bool Reset() override;

  // Finishes the iteration when the configured resetter round is kPushMetrics.
  void OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) override;

 private:
  // Stores the metrics of the request, counts the request and builds the response in fbb.
  ResultCode PushMetrics(const std::shared_ptr<FBBuilder> &fbb, const schema::RequestPushMetrics *push_metrics_req);

  // Serializes a ResponsePushMetrics message with the given return code into fbb.
  void BuildPushMetricsRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode);

  // The rank of this server, set in InitKernel. Default-initialized so it never holds an indeterminate value when
  // it is read before InitKernel has run.
  uint32_t local_rank_ = 0;
};
} // namespace kernel
} // namespace server
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_PUSH_METRICS_KERNEL_H_

View File

@ -42,6 +42,9 @@ constexpr uint32_t kWorkerSleepTimeForNetworking = 1000;
// The time duration between retrying when server is in safemode.
constexpr uint32_t kWorkerRetryDurationForSafeMode = 500;
// The leader server rank.
constexpr uint32_t kLeaderServerRank = 0;
enum class IterationState {
// This iteration is still in process.
kRunning,

View File

@ -54,7 +54,8 @@ enum class TcpUserCommand {
kEndLastIter,
kStartFLJob,
kUpdateModel,
kGetModel
kGetModel,
kPushMetrics
};
const std::unordered_map<TcpUserCommand, std::string> kUserCommandToMsgType = {
@ -75,7 +76,8 @@ const std::unordered_map<TcpUserCommand, std::string> kUserCommandToMsgType = {
{TcpUserCommand::kEndLastIter, "endLastIter"},
{TcpUserCommand::kStartFLJob, "startFLJob"},
{TcpUserCommand::kUpdateModel, "updateModel"},
{TcpUserCommand::kGetModel, "getModel"}};
{TcpUserCommand::kGetModel, "getModel"},
{TcpUserCommand::kPushMetrics, "pushMetrics"}};
class TcpCommunicator : public CommunicatorBase {
public:

View File

@ -46,11 +46,11 @@ constexpr char kNotEncryptType[] = "NOT_ENCRYPT";
// 2: Server is in mixed training mode.
// 3: Server enables pairwise encrypt algorithm.
// For example: 1010 stands for that the server is in federated learning mode and pairwise encrypt algorithm is enabled.
enum class ResetterRound { kNoNeedToReset, kUpdateModel, kReconstructSeccrets, kPushWeight };
enum class ResetterRound { kNoNeedToReset, kUpdateModel, kReconstructSeccrets, kPushWeight, kPushMetrics };
const std::map<uint32_t, ResetterRound> kServerContextToResetRoundMap = {{0b0010, ResetterRound::kUpdateModel},
{0b1010, ResetterRound::kReconstructSeccrets},
{0b1100, ResetterRound::kPushWeight},
{0b0100, ResetterRound::kPushWeight}};
{0b1100, ResetterRound::kPushMetrics},
{0b0100, ResetterRound::kPushMetrics}};
class PSContext {
public:

View File

@ -160,3 +160,12 @@ table ResponsePullWeight{
table FeatureMapList {
feature_map:[FeatureMap];
}
// Request sent by a worker to push the metrics of an iteration to the server.
table RequestPushMetrics{
// The loss value reported by the worker.
loss:float;
// The accuracy value reported by the worker.
accuracy:float;
}
// Response of the server to a RequestPushMetrics message.
table ResponsePushMetrics{
// The result of the request. The server fills it with a ResponseCode value.
retcode:int;
}