forked from mindspore-Ecosystem/mindspore
Add iteration metrics
This commit is contained in:
parent
b59e9c3e20
commit
c26e07aadf
|
@ -84,6 +84,7 @@ if(NOT ENABLE_CPU OR WIN32)
|
|||
list(REMOVE_ITEM CPU_SRC_LIST "cpu/fl/get_model_kernel.cc")
|
||||
list(REMOVE_ITEM CPU_SRC_LIST "cpu/fl/start_fl_job_kernel.cc")
|
||||
list(REMOVE_ITEM CPU_SRC_LIST "cpu/fl/update_model_kernel.cc")
|
||||
list(REMOVE_ITEM CPU_SRC_LIST "cpu/fl/push_metrics_kernel.cc")
|
||||
endif()
|
||||
|
||||
if(ENABLE_GPU)
|
||||
|
|
|
@ -0,0 +1,26 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/cpu/fl/push_metrics_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
// Register the float instantiation of PushMetricsKernel under the operator name PushMetrics.
// The kernel attribute declares two float32 inputs and one float32 output.
MS_REG_CPU_KERNEL_T(PushMetrics,
                    KernelAttr()
                      .AddInputAttr(kNumberTypeFloat32)
                      .AddInputAttr(kNumberTypeFloat32)
                      .AddOutputAttr(kNumberTypeFloat32),
                    PushMetricsKernel, float);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,125 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_FL_PUSH_METRICS_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_FL_PUSH_METRICS_H_
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <functional>
|
||||
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
|
||||
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
|
||||
#include "fl/worker/fl_worker.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
// The duration in milliseconds to wait between two PushMetrics request attempts.
|
||||
constexpr int kRetryDurationOfPushMetrics = 500;
|
||||
// Maximum number of PushMetrics attempts: 3600 attempts at a 500 ms interval is about 30 minutes of retrying.
|
||||
constexpr int kMaxRetryTime = 3600;
|
||||
template <typename T>
|
||||
class PushMetricsKernel : public CPUKernel {
|
||||
public:
|
||||
PushMetricsKernel() : fbb_(nullptr), total_iteration_(0) {}
|
||||
~PushMetricsKernel() override = default;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, const std::vector<AddressPtr> &) {
|
||||
if (inputs.size() != 2) {
|
||||
MS_LOG(EXCEPTION) << "Input number of PushMetricsKernel should be " << 2 << ", but got " << inputs.size();
|
||||
return false;
|
||||
}
|
||||
|
||||
MS_EXCEPTION_IF_NULL(inputs[0]->addr);
|
||||
MS_EXCEPTION_IF_NULL(inputs[1]->addr);
|
||||
T loss = *(static_cast<float *>(inputs[0]->addr));
|
||||
T accuracy = *(static_cast<float *>(inputs[1]->addr));
|
||||
|
||||
if (!BuildPushMetricsReq(fbb_, loss, accuracy)) {
|
||||
MS_LOG(EXCEPTION) << "Building request for FusedPushWeight failed.";
|
||||
return false;
|
||||
}
|
||||
|
||||
uint32_t retry_time = 0;
|
||||
std::shared_ptr<std::vector<unsigned char>> push_metrics_rsp_msg = nullptr;
|
||||
do {
|
||||
retry_time++;
|
||||
if (!fl::worker::FLWorker::GetInstance().SendToServer(fl::kLeaderServerRank, fbb_->GetBufferPointer(),
|
||||
fbb_->GetSize(), ps::core::TcpUserCommand::kPushMetrics,
|
||||
&push_metrics_rsp_msg)) {
|
||||
MS_LOG(WARNING) << "Sending request for PushMetrics to server 0 failed.";
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(kRetryDurationOfPushMetrics));
|
||||
continue;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
} while (retry_time > kMaxRetryTime);
|
||||
|
||||
flatbuffers::Verifier verifier(push_metrics_rsp_msg->data(), push_metrics_rsp_msg->size());
|
||||
if (!verifier.VerifyBuffer<schema::ResponsePushMetrics>()) {
|
||||
MS_LOG(EXCEPTION) << "The schema of ResponsePushMetrics is invalid.";
|
||||
return false;
|
||||
}
|
||||
|
||||
const schema::ResponsePushMetrics *push_metrics_rsp =
|
||||
flatbuffers::GetRoot<schema::ResponsePushMetrics>(push_metrics_rsp_msg->data());
|
||||
MS_EXCEPTION_IF_NULL(push_metrics_rsp);
|
||||
auto response_code = push_metrics_rsp->retcode();
|
||||
switch (response_code) {
|
||||
case schema::ResponseCode_SUCCEED:
|
||||
case schema::ResponseCode_OutOfTime:
|
||||
break;
|
||||
default:
|
||||
MS_LOG(EXCEPTION) << "Launching push metrics for worker failed.";
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "Push metrics for loss and accuracy success.";
|
||||
fl::worker::FLWorker::GetInstance().SetIterationCompleted();
|
||||
return true;
|
||||
}
|
||||
|
||||
void Init(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
fbb_ = std::make_shared<fl::FBBuilder>();
|
||||
MS_EXCEPTION_IF_NULL(fbb_);
|
||||
input_size_list_.push_back(sizeof(float));
|
||||
input_size_list_.push_back(sizeof(float));
|
||||
output_size_list_.push_back(sizeof(float));
|
||||
}
|
||||
|
||||
void InitKernel(const CNodePtr &kernel_node) { return; }
|
||||
|
||||
protected:
|
||||
void InitSizeLists() { return; }
|
||||
|
||||
private:
|
||||
bool BuildPushMetricsReq(const std::shared_ptr<fl::FBBuilder> &fbb, T loss, T accuracy) {
|
||||
MS_EXCEPTION_IF_NULL(fbb);
|
||||
schema::RequestPushMetricsBuilder req_push_metrics_builder(*(fbb.get()));
|
||||
req_push_metrics_builder.add_loss(loss);
|
||||
req_push_metrics_builder.add_accuracy(accuracy);
|
||||
auto req_push_metrics = req_push_metrics_builder.Finish();
|
||||
fbb->Finish(req_push_metrics);
|
||||
return true;
|
||||
}
|
||||
|
||||
std::shared_ptr<fl::FBBuilder> fbb_;
|
||||
size_t total_iteration_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_FL_PUSH_METRICS_H_
|
|
@ -20,6 +20,7 @@ if(NOT ENABLE_CPU OR WIN32)
|
|||
list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/round/get_secrets_kernel.cc")
|
||||
list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/round/reconstruct_secrets_kernel.cc")
|
||||
list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/round/share_secrets_kernel.cc")
|
||||
list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/round/push_metrics_kernel.cc")
|
||||
list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/params_info.cc")
|
||||
list(REMOVE_ITEM _FL_SRC_FILES "server/consistent_hash_ring.cc")
|
||||
list(REMOVE_ITEM _FL_SRC_FILES "server/iteration_timer.cc")
|
||||
|
@ -34,6 +35,7 @@ if(NOT ENABLE_CPU OR WIN32)
|
|||
list(REMOVE_ITEM _FL_SRC_FILES "server/model_store.cc")
|
||||
list(REMOVE_ITEM _FL_SRC_FILES "server/round.cc")
|
||||
list(REMOVE_ITEM _FL_SRC_FILES "server/server.cc")
|
||||
list(REMOVE_ITEM _FL_SRC_FILES "server/iteration_metrics.cc")
|
||||
list(REMOVE_ITEM _FL_SRC_FILES "worker/fl_worker.cc")
|
||||
list(REMOVE_ITEM _FL_SRC_FILES "armour/secure_protocol/encrypt.cc")
|
||||
list(REMOVE_ITEM _FL_SRC_FILES "armour/secure_protocol/key_agreement.cc")
|
||||
|
|
|
@ -65,6 +65,20 @@ struct CipherConfig {
|
|||
size_t reconstruct_secrets_threshold = 0;
|
||||
};
|
||||
|
||||
// Every instance is one training loop that runs fl_iteration_num iterations of federated learning.
|
||||
// During every instance, server's training process could be controlled by scheduler, which will change the state of
|
||||
// this instance.
|
||||
enum class InstanceState {
|
||||
// If this instance is in kRunning state, server could communicate with client/worker and the training process moves
|
||||
// on.
|
||||
kRunning = 0,
|
||||
// The server is not available for client/worker if in kDisable state.
|
||||
kDisable,
|
||||
// The server is not available for client/worker if in kFinish state. And this state means one instance has finished.
|
||||
// In other words, fl_iteration_num iterations are completed.
|
||||
kFinish
|
||||
};
|
||||
|
||||
using mindspore::kernel::Address;
|
||||
using mindspore::kernel::AddressPtr;
|
||||
using mindspore::kernel::CPUKernel;
|
||||
|
|
|
@ -160,6 +160,10 @@ const std::vector<std::shared_ptr<Round>> &Iteration::rounds() const { return ro
|
|||
|
||||
// Returns the stored is_last_iteration_valid_ flag.
bool Iteration::is_last_iteration_valid() const {
  return is_last_iteration_valid_;
}
|
||||
|
||||
// Stores the given loss value in loss_.
void Iteration::set_loss(float loss) {
  loss_ = loss;
}
|
||||
|
||||
// Stores the given accuracy value in accuracy_.
void Iteration::set_accuracy(float accuracy) {
  accuracy_ = accuracy;
}
|
||||
|
||||
bool Iteration::SyncIteration(uint32_t rank) {
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(communicator_, false);
|
||||
SyncIterationRequest sync_iter_req;
|
||||
|
|
|
@ -83,6 +83,10 @@ class Iteration {
|
|||
|
||||
bool is_last_iteration_valid() const;
|
||||
|
||||
// Set the instance metrics which will be called for each iteration.
|
||||
void set_loss(float loss);
|
||||
void set_accuracy(float accuracy);
|
||||
|
||||
private:
|
||||
Iteration()
|
||||
: server_node_(nullptr),
|
||||
|
@ -91,7 +95,9 @@ class Iteration {
|
|||
iteration_loop_count_(0),
|
||||
iteration_num_(1),
|
||||
is_last_iteration_valid_(true),
|
||||
pinned_iter_num_(0) {
|
||||
pinned_iter_num_(0),
|
||||
loss_(0.0),
|
||||
accuracy_(0.0) {
|
||||
LocalMetaStore::GetInstance().set_curr_iter_num(iteration_num_);
|
||||
}
|
||||
~Iteration() = default;
|
||||
|
@ -153,6 +159,12 @@ class Iteration {
|
|||
// To avoid Next method is called multiple times in one iteration, we should mark the iteration number.
|
||||
uint64_t pinned_iter_num_;
|
||||
std::mutex pinned_mtx_;
|
||||
|
||||
// The training loss after this federated learning iteration, passed by worker.
|
||||
float loss_;
|
||||
|
||||
// The evaluation result after this federated learning iteration, passed by worker.
|
||||
float accuracy_;
|
||||
};
|
||||
} // namespace server
|
||||
} // namespace fl
|
||||
|
|
|
@ -0,0 +1,117 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <string>
|
||||
#include <fstream>
|
||||
#include "debug/common.h"
|
||||
#include "ps/constants.h"
|
||||
#include "fl/server/iteration_metrics.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace fl {
|
||||
namespace server {
|
||||
bool IterationMetrics::Initialize() {
|
||||
config_ = std::make_unique<ps::core::FileConfiguration>(config_file_path_);
|
||||
MS_EXCEPTION_IF_NULL(config_);
|
||||
if (!config_->Initialize()) {
|
||||
MS_LOG(EXCEPTION) << "Initializing for metrics failed. Config file path " << config_file_path_
|
||||
<< " may be invalid or not exist.";
|
||||
return false;
|
||||
}
|
||||
|
||||
// Read the metrics file path. If file is not set or not exits, create one.
|
||||
if (!config_->Exists(kMetrics)) {
|
||||
MS_LOG(WARNING) << "Metrics config is not set. Don't write metrics.";
|
||||
return false;
|
||||
} else {
|
||||
std::string value = config_->Get(kMetrics, "");
|
||||
nlohmann::json value_json;
|
||||
try {
|
||||
value_json = nlohmann::json::parse(value);
|
||||
} catch (const std::exception &e) {
|
||||
MS_LOG(EXCEPTION) << "The hyper-parameter data is not in json format.";
|
||||
return false;
|
||||
}
|
||||
|
||||
// Parse the storage type.
|
||||
uint32_t storage_type = JsonGetKeyWithException<uint32_t>(value_json, ps::kStoreType);
|
||||
if (std::to_string(storage_type) != ps::kFileStorage) {
|
||||
MS_LOG(EXCEPTION) << "Storage type " << storage_type << " is not supported.";
|
||||
return false;
|
||||
}
|
||||
|
||||
// Parse storage file path.
|
||||
std::string metrics_file_path = JsonGetKeyWithException<std::string>(value_json, ps::kStoreFilePath);
|
||||
auto realpath = Common::GetRealPath(metrics_file_path);
|
||||
if (!realpath.has_value()) {
|
||||
MS_LOG(EXCEPTION) << "Get real path for " << metrics_file_path << " failed.";
|
||||
return false;
|
||||
}
|
||||
|
||||
metrics_file_.open(metrics_file_path, std::ios::ate | std::ios::out);
|
||||
}
|
||||
|
||||
initialized_ = true;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool IterationMetrics::Summarize() {
|
||||
if (!metrics_file_.is_open()) {
|
||||
metrics_file_.clear();
|
||||
MS_LOG(ERROR) << "The metrics file is not opened.";
|
||||
return false;
|
||||
}
|
||||
|
||||
js_[kFLName] = fl_name_;
|
||||
js_[kInstanceStatus] = kInstanceStateName.at(instance_state_);
|
||||
js_[kFLIterationNum] = fl_iteration_num_;
|
||||
js_[kCurIteration] = cur_iteration_num_;
|
||||
js_[kJoinedClientNum] = joined_client_num_;
|
||||
js_[kRejectedClientNum] = rejected_client_num_;
|
||||
js_[kMetricsAuc] = accuracy_;
|
||||
js_[kMetricsLoss] = loss_;
|
||||
js_[kIterExecutionTime] = iteration_time_cost_;
|
||||
metrics_file_ << js_ << "\n";
|
||||
metrics_file_.flush();
|
||||
return true;
|
||||
}
|
||||
|
||||
// Currently performs no work and always reports success.
bool IterationMetrics::Clear() {
  return true;
}
|
||||
|
||||
void IterationMetrics::set_fl_name(const std::string &fl_name) { fl_name_ = fl_name; }
|
||||
|
||||
void IterationMetrics::set_fl_iteration_num(size_t fl_iteration_num) { fl_iteration_num_ = fl_iteration_num; }
|
||||
|
||||
void IterationMetrics::set_cur_iteration_num(size_t cur_iteration_num) { cur_iteration_num_ = cur_iteration_num; }
|
||||
|
||||
void IterationMetrics::set_instance_state(InstanceState state) { instance_state_ = state; }
|
||||
|
||||
void IterationMetrics::set_loss(float loss) { loss_ = loss; }
|
||||
|
||||
void IterationMetrics::set_accuracy(float acc) { accuracy_ = acc; }
|
||||
|
||||
void IterationMetrics::set_joined_client_num(size_t joined_client_num) { joined_client_num_ = joined_client_num; }
|
||||
|
||||
void IterationMetrics::set_rejected_client_num(size_t rejected_client_num) {
|
||||
rejected_client_num_ = rejected_client_num;
|
||||
}
|
||||
|
||||
void IterationMetrics::set_iteration_time_cost(uint64_t iteration_time_cost) {
|
||||
iteration_time_cost_ = iteration_time_cost;
|
||||
}
|
||||
} // namespace server
|
||||
} // namespace fl
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,135 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_FL_SERVER_ITERATION_METRICS_H_
|
||||
#define MINDSPORE_CCSRC_FL_SERVER_ITERATION_METRICS_H_
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <fstream>
|
||||
#include "ps/ps_context.h"
|
||||
#include "ps/core/configuration.h"
|
||||
#include "ps/core/file_configuration.h"
|
||||
#include "fl/server/local_meta_store.h"
|
||||
#include "fl/server/iteration.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace fl {
|
||||
namespace server {
|
||||
constexpr auto kFLName = "flName";
|
||||
constexpr auto kInstanceStatus = "instanceStatus";
|
||||
constexpr auto kFLIterationNum = "flIterationNum";
|
||||
constexpr auto kCurIteration = "currentIteration";
|
||||
constexpr auto kJoinedClientNum = "joinedClientNum";
|
||||
constexpr auto kRejectedClientNum = "rejectedClientNum";
|
||||
constexpr auto kMetricsAuc = "metricsAuc";
|
||||
constexpr auto kMetricsLoss = "metricsLoss";
|
||||
constexpr auto kIterExecutionTime = "iterationExecutionTime";
|
||||
|
||||
const std::map<InstanceState, std::string> kInstanceStateName = {
|
||||
{InstanceState::kRunning, "running"}, {InstanceState::kDisable, "disable"}, {InstanceState::kFinish, "finish"}};
|
||||
|
||||
template <typename T>
|
||||
inline T JsonGetKeyWithException(const nlohmann::json &json, const std::string &key) {
|
||||
if (!json.contains(key)) {
|
||||
MS_LOG(EXCEPTION) << "The key " << key << "does not exist in json " << json.dump();
|
||||
}
|
||||
return json[key].get<T>();
|
||||
}
|
||||
|
||||
constexpr auto kMetrics = "metrics";
|
||||
|
||||
class IterationMetrics {
|
||||
public:
|
||||
explicit IterationMetrics(const std::string &config_file)
|
||||
: config_file_path_(config_file),
|
||||
config_(nullptr),
|
||||
fl_name_(""),
|
||||
fl_iteration_num_(0),
|
||||
cur_iteration_num_(0),
|
||||
instance_state_(InstanceState::kFinish),
|
||||
loss_(0.0),
|
||||
accuracy_(0.0),
|
||||
joined_client_num_(0),
|
||||
rejected_client_num_(0),
|
||||
iteration_time_cost_(0) {}
|
||||
~IterationMetrics() = default;
|
||||
|
||||
bool Initialize();
|
||||
|
||||
// Gather the details of this iteration and output to the persistent storage.
|
||||
bool Summarize();
|
||||
|
||||
// Clear data in persistent storage.
|
||||
bool Clear();
|
||||
|
||||
// Setters for the metrics data.
|
||||
void set_fl_name(const std::string &fl_name);
|
||||
void set_fl_iteration_num(size_t fl_iteration_num);
|
||||
void set_cur_iteration_num(size_t cur_iteration_num);
|
||||
void set_instance_state(InstanceState state);
|
||||
void set_loss(float loss);
|
||||
void set_accuracy(float acc);
|
||||
void set_joined_client_num(size_t joined_client_num);
|
||||
void set_rejected_client_num(size_t rejected_client_num);
|
||||
void set_iteration_time_cost(uint64_t iteration_time_cost);
|
||||
|
||||
private:
|
||||
bool initialized_;
|
||||
|
||||
// This is the main config file set by ps context.
|
||||
std::string config_file_path_;
|
||||
std::unique_ptr<ps::core::FileConfiguration> config_;
|
||||
|
||||
// The metrics file object.
|
||||
std::fstream metrics_file_;
|
||||
|
||||
// Json object of metrics data.
|
||||
nlohmann::json js_;
|
||||
|
||||
// The federated learning job name. Set by ps_context.
|
||||
std::string fl_name_;
|
||||
|
||||
// Federated learning iteration number. Set by ps_context.
|
||||
// If this number of iterations are completed, one instance is finished.
|
||||
size_t fl_iteration_num_;
|
||||
|
||||
// Current iteration number.
|
||||
size_t cur_iteration_num_;
|
||||
|
||||
// Current instance state.
|
||||
InstanceState instance_state_;
|
||||
|
||||
// The training loss after this federated learning iteration, passed by worker.
|
||||
float loss_;
|
||||
|
||||
// The evaluation result after this federated learning iteration, passed by worker.
|
||||
float accuracy_;
|
||||
|
||||
// The number of clients which join the federated aggregation.
|
||||
size_t joined_client_num_;
|
||||
|
||||
// The number of clients which are not involved in federated aggregation.
|
||||
size_t rejected_client_num_;
|
||||
|
||||
// The time cost in millisecond for this completed iteration.
|
||||
uint64_t iteration_time_cost_;
|
||||
};
|
||||
} // namespace server
|
||||
} // namespace fl
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_FL_SERVER_ITERATION_METRICS_H_
|
|
@ -0,0 +1,112 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <string>
|
||||
#include "fl/server/kernel/round/push_metrics_kernel.h"
|
||||
#include "fl/server/iteration.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace fl {
|
||||
namespace server {
|
||||
namespace kernel {
|
||||
// Records the local rank reported by DistributedCountService. threshold_count is not used here.
void PushMetricsKernel::InitKernel(size_t threshold_count) {
  const auto rank = DistributedCountService::GetInstance().local_rank();
  local_rank_ = rank;
}
|
||||
|
||||
bool PushMetricsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
MS_LOG(INFO) << "Launching PushMetricsKernel kernel.";
|
||||
void *req_data = inputs[0]->addr;
|
||||
std::shared_ptr<FBBuilder> fbb = std::make_shared<FBBuilder>();
|
||||
if (fbb == nullptr || req_data == nullptr) {
|
||||
std::string reason = "FBBuilder builder or req_data is nullptr.";
|
||||
MS_LOG(ERROR) << reason;
|
||||
GenerateOutput(outputs, reason.c_str(), reason.size());
|
||||
return true;
|
||||
}
|
||||
|
||||
flatbuffers::Verifier verifier(reinterpret_cast<uint8_t *>(req_data), inputs[0]->size);
|
||||
if (!verifier.VerifyBuffer<schema::RequestPushMetrics>()) {
|
||||
std::string reason = "The schema of RequestPushMetrics is invalid.";
|
||||
BuildPushMetricsRsp(fbb, schema::ResponseCode_RequestError);
|
||||
MS_LOG(ERROR) << reason;
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
|
||||
const schema::RequestPushMetrics *push_metrics_req = flatbuffers::GetRoot<schema::RequestPushMetrics>(req_data);
|
||||
if (push_metrics_req == nullptr) {
|
||||
std::string reason = "Building flatbuffers schema failed for RequestPushMetrics";
|
||||
BuildPushMetricsRsp(fbb, schema::ResponseCode_RequestError);
|
||||
MS_LOG(ERROR) << reason;
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return false;
|
||||
}
|
||||
|
||||
ResultCode result_code = PushMetrics(fbb, push_metrics_req);
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return ConvertResultCode(result_code);
|
||||
}
|
||||
|
||||
// Stops the timer and resets the distributed counter registered under name_. Always returns true.
bool PushMetricsKernel::Reset() {
  MS_LOG(INFO) << "PushMetricsKernel reset!";

  StopTimer();
  DistributedCountService::GetInstance().ResetCounter(name_);

  return true;
}
|
||||
|
||||
// Finishes the iteration, but only when the configured resetter round is kPushMetrics.
void PushMetricsKernel::OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &) {
  if (ps::PSContext::instance()->resetter_round() != ps::ResetterRound::kPushMetrics) {
    return;
  }
  FinishIteration();
}
|
||||
|
||||
// Stores the loss and accuracy carried by the request in the Iteration singleton, counts the request and builds
// the matching response in fbb.
ResultCode PushMetricsKernel::PushMetrics(const std::shared_ptr<FBBuilder> &fbb,
                                          const schema::RequestPushMetrics *push_metrics_req) {
  MS_ERROR_IF_NULL_W_RET_VAL(fbb, ResultCode::kSuccessAndReturn);
  MS_ERROR_IF_NULL_W_RET_VAL(push_metrics_req, ResultCode::kSuccessAndReturn);

  const float reported_loss = push_metrics_req->loss();
  const float reported_accuracy = push_metrics_req->accuracy();
  Iteration::GetInstance().set_loss(reported_loss);
  Iteration::GetInstance().set_accuracy(reported_accuracy);

  std::string count_reason = "";
  const bool counted = DistributedCountService::GetInstance().Count(name_, std::to_string(local_rank_), &count_reason);
  if (counted) {
    BuildPushMetricsRsp(fbb, schema::ResponseCode_SUCCEED);
    return ResultCode::kSuccess;
  }

  // Counting failed: answer with a system error. Only a network error is reported as a failure.
  std::string reason = "Count for push metrics request failed.";
  BuildPushMetricsRsp(fbb, schema::ResponseCode_SystemError);
  MS_LOG(ERROR) << reason;
  if (count_reason == kNetworkError) {
    return ResultCode::kFail;
  }
  return ResultCode::kSuccessAndReturn;
}
|
||||
|
||||
// Builds a ResponsePushMetrics message carrying retcode and finishes it in fbb.
void PushMetricsKernel::BuildPushMetricsRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode) {
  MS_ERROR_IF_NULL_WO_RET_VAL(fbb);
  schema::ResponsePushMetricsBuilder response_builder(*fbb);
  response_builder.add_retcode(retcode);
  fbb->Finish(response_builder.Finish());
}
|
||||
|
||||
REG_ROUND_KERNEL(pushMetrics, PushMetricsKernel)
|
||||
} // namespace kernel
|
||||
} // namespace server
|
||||
} // namespace fl
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,54 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_PUSH_METRICS_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_PUSH_METRICS_KERNEL_H_
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "fl/server/common.h"
|
||||
#include "fl/server/kernel/round/round_kernel.h"
|
||||
#include "fl/server/kernel/round/round_kernel_factory.h"
|
||||
#include "fl/server/executor.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace fl {
|
||||
namespace server {
|
||||
namespace kernel {
|
||||
class PushMetricsKernel : public RoundKernel {
|
||||
public:
|
||||
PushMetricsKernel() = default;
|
||||
~PushMetricsKernel() override = default;
|
||||
|
||||
void InitKernel(size_t threshold_count);
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs);
|
||||
bool Reset() override;
|
||||
void OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) override;
|
||||
|
||||
private:
|
||||
ResultCode PushMetrics(const std::shared_ptr<FBBuilder> &fbb, const schema::RequestPushMetrics *push_metrics_req);
|
||||
void BuildPushMetricsRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode);
|
||||
|
||||
uint32_t local_rank_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace server
|
||||
} // namespace fl
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_PUSH_METRICS_KERNEL_H_
|
|
@ -42,6 +42,9 @@ constexpr uint32_t kWorkerSleepTimeForNetworking = 1000;
|
|||
// The time duration between retrying when server is in safemode.
|
||||
constexpr uint32_t kWorkerRetryDurationForSafeMode = 500;
|
||||
|
||||
// The leader server rank.
|
||||
constexpr uint32_t kLeaderServerRank = 0;
|
||||
|
||||
enum class IterationState {
|
||||
// This iteration is still in process.
|
||||
kRunning,
|
||||
|
|
|
@ -54,7 +54,8 @@ enum class TcpUserCommand {
|
|||
kEndLastIter,
|
||||
kStartFLJob,
|
||||
kUpdateModel,
|
||||
kGetModel
|
||||
kGetModel,
|
||||
kPushMetrics
|
||||
};
|
||||
|
||||
const std::unordered_map<TcpUserCommand, std::string> kUserCommandToMsgType = {
|
||||
|
@ -75,7 +76,8 @@ const std::unordered_map<TcpUserCommand, std::string> kUserCommandToMsgType = {
|
|||
{TcpUserCommand::kEndLastIter, "endLastIter"},
|
||||
{TcpUserCommand::kStartFLJob, "startFLJob"},
|
||||
{TcpUserCommand::kUpdateModel, "updateModel"},
|
||||
{TcpUserCommand::kGetModel, "getModel"}};
|
||||
{TcpUserCommand::kGetModel, "getModel"},
|
||||
{TcpUserCommand::kPushMetrics, "pushMetrics"}};
|
||||
|
||||
class TcpCommunicator : public CommunicatorBase {
|
||||
public:
|
||||
|
|
|
@ -46,11 +46,11 @@ constexpr char kNotEncryptType[] = "NOT_ENCRYPT";
|
|||
// 2: Server is in mixed training mode.
|
||||
// 3: Server enables pairwise encrypt algorithm.
|
||||
// For example: 1010 stands for that the server is in federated learning mode and pairwise encrypt algorithm is enabled.
|
||||
enum class ResetterRound { kNoNeedToReset, kUpdateModel, kReconstructSeccrets, kPushWeight };
|
||||
enum class ResetterRound { kNoNeedToReset, kUpdateModel, kReconstructSeccrets, kPushWeight, kPushMetrics };
|
||||
const std::map<uint32_t, ResetterRound> kServerContextToResetRoundMap = {{0b0010, ResetterRound::kUpdateModel},
|
||||
{0b1010, ResetterRound::kReconstructSeccrets},
|
||||
{0b1100, ResetterRound::kPushWeight},
|
||||
{0b0100, ResetterRound::kPushWeight}};
|
||||
{0b1100, ResetterRound::kPushMetrics},
|
||||
{0b0100, ResetterRound::kPushMetrics}};
|
||||
|
||||
class PSContext {
|
||||
public:
|
||||
|
|
|
@ -160,3 +160,12 @@ table ResponsePullWeight{
|
|||
table FeatureMapList {
|
||||
feature_map:[FeatureMap];
|
||||
}
|
||||
|
||||
// Request carrying the two metric values pushed to the server.
table RequestPushMetrics {
  loss: float;
  accuracy: float;
}
|
||||
|
||||
// Response to a RequestPushMetrics message; carries only the return code.
table ResponsePushMetrics {
  retcode: int;
}
|
||||
|
|
Loading…
Reference in New Issue