commit
4b1df4d5f5
|
@ -72,21 +72,18 @@ void Iteration::InitRounds(const std::vector<std::shared_ptr<ps::core::Communica
|
|||
MS_LOG(EXCEPTION) << "Communicators for rounds is empty.";
|
||||
return;
|
||||
}
|
||||
|
||||
(void)std::for_each(communicators.begin(), communicators.end(),
|
||||
[&](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) {
|
||||
for (auto &round : rounds_) {
|
||||
MS_EXCEPTION_IF_NULL(round);
|
||||
round->Initialize(communicator, timeout_cb, finish_iteration_cb);
|
||||
}
|
||||
});
|
||||
|
||||
// The time window for one iteration, which will be used in some round kernels.
|
||||
size_t iteration_time_window = std::accumulate(rounds_.begin(), rounds_.end(), IntToSize(0),
|
||||
[](size_t total, const std::shared_ptr<Round> &round) {
|
||||
MS_EXCEPTION_IF_NULL(round);
|
||||
return round->check_timeout() ? total + round->time_window() : total;
|
||||
});
|
||||
size_t iteration_time_window = 0;
|
||||
for (auto &round : rounds_) {
|
||||
MS_EXCEPTION_IF_NULL(round);
|
||||
round->Initialize(timeout_cb, finish_iteration_cb);
|
||||
for (auto &communicator : communicators) {
|
||||
round->RegisterMsgCallBack(communicator);
|
||||
}
|
||||
if (round->check_timeout()) {
|
||||
iteration_time_window += round->time_window();
|
||||
}
|
||||
}
|
||||
LocalMetaStore::GetInstance().put_value(kCtxTotalTimeoutDuration, iteration_time_window);
|
||||
MS_LOG(INFO) << "Time window for one iteration is " << iteration_time_window;
|
||||
|
||||
|
|
|
@ -160,7 +160,7 @@ class FedAvgKernel : public AggregationKernel {
|
|||
bool IsAggregationDone() override { return done_; }
|
||||
|
||||
void SetParameterAddress(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
const std::vector<AddressPtr> &outputs) override {
|
||||
weight_addr_ = inputs[0];
|
||||
data_size_addr_ = inputs[1];
|
||||
new_weight_addr_ = inputs[2];
|
||||
|
|
|
@ -131,31 +131,25 @@ bool ClientListKernel::DealClient(const size_t iter_num, const schema::GetClient
|
|||
return true;
|
||||
}
|
||||
|
||||
bool ClientListKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
bool ClientListKernel::Launch(const uint8_t *req_data, size_t len,
|
||||
const std::shared_ptr<ps::core::MessageHandler> &message) {
|
||||
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
|
||||
MS_LOG(INFO) << "Launching ClientListKernel, Iteration number is " << iter_num;
|
||||
if (inputs.size() != 1 || outputs.size() != 1) {
|
||||
std::string reason = "inputs or outputs size is invalid.";
|
||||
MS_LOG(ERROR) << reason;
|
||||
return false;
|
||||
}
|
||||
|
||||
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
|
||||
void *req_data = inputs[0]->addr;
|
||||
if (fbb == nullptr || req_data == nullptr) {
|
||||
std::string reason = "FBBuilder builder or req_data is nullptr.";
|
||||
MS_LOG(ERROR) << reason;
|
||||
return false;
|
||||
}
|
||||
std::vector<string> client_list;
|
||||
flatbuffers::Verifier verifier(reinterpret_cast<uint8_t *>(req_data), inputs[0]->size);
|
||||
flatbuffers::Verifier verifier(req_data, len);
|
||||
if (!verifier.VerifyBuffer<schema::GetClientList>()) {
|
||||
std::string reason = "The schema of GetClientList is invalid.";
|
||||
BuildClientListRsp(fbb, schema::ResponseCode_RequestError, reason, client_list,
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||
MS_LOG(ERROR) << reason;
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
const schema::GetClientList *get_clients_req = flatbuffers::GetRoot<schema::GetClientList>(req_data);
|
||||
|
@ -164,7 +158,7 @@ bool ClientListKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
|
|||
BuildClientListRsp(fbb, schema::ResponseCode_RequestError, reason, client_list,
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||
MS_LOG(ERROR) << reason;
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
// verify signature
|
||||
|
@ -175,7 +169,7 @@ bool ClientListKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
|
|||
BuildClientListRsp(fbb, schema::ResponseCode_RequestError, reason, client_list,
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||
MS_LOG(ERROR) << reason;
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
if (verify_result == sigVerifyResult::TIMEOUT) {
|
||||
|
@ -183,7 +177,7 @@ bool ClientListKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
|
|||
BuildClientListRsp(fbb, schema::ResponseCode_OutOfTime, reason, client_list,
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||
MS_LOG(ERROR) << reason;
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
MS_LOG(DEBUG) << "verify signature passed!";
|
||||
|
@ -195,7 +189,7 @@ bool ClientListKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
|
|||
<< ". client request iteration is " << iter_client;
|
||||
BuildClientListRsp(fbb, schema::ResponseCode_OutOfTime, "iter num is error.", client_list,
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -206,7 +200,7 @@ bool ClientListKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
|
|||
if (!DealClient(iter_num, get_clients_req, fbb)) {
|
||||
MS_LOG(WARNING) << "Get Client List not ready.";
|
||||
}
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
} // namespace fl
|
||||
|
||||
|
|
|
@ -37,8 +37,7 @@ class ClientListKernel : public RoundKernel {
|
|||
ClientListKernel() = default;
|
||||
~ClientListKernel() override = default;
|
||||
void InitKernel(size_t required_cnt) override;
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
bool Launch(const uint8_t *req_data, size_t len, const std::shared_ptr<ps::core::MessageHandler> &message) override;
|
||||
bool Reset() override;
|
||||
void BuildClientListRsp(const std::shared_ptr<server::FBBuilder> &fbb, const schema::ResponseCode retcode,
|
||||
const string &reason, std::vector<std::string> clients, const string &next_req_time,
|
||||
|
|
|
@ -123,18 +123,12 @@ sigVerifyResult ExchangeKeysKernel::VerifySignature(const schema::RequestExchang
|
|||
return sigVerifyResult::PASSED;
|
||||
}
|
||||
|
||||
bool ExchangeKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
bool ExchangeKeysKernel::Launch(const uint8_t *req_data, size_t len,
|
||||
const std::shared_ptr<ps::core::MessageHandler> &message) {
|
||||
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
|
||||
MS_LOG(INFO) << "Launching ExchangeKey kernel, ITERATION NUMBER IS : " << iter_num;
|
||||
bool response = false;
|
||||
if (inputs.size() != 1 || outputs.size() != 1) {
|
||||
std::string reason = "inputs or outputs size is invalid.";
|
||||
MS_LOG(ERROR) << reason;
|
||||
return false;
|
||||
}
|
||||
|
||||
void *req_data = inputs[0]->addr;
|
||||
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
|
||||
if (fbb == nullptr || req_data == nullptr) {
|
||||
std::string reason = "FBBuilder builder or req_data is nullptr.";
|
||||
|
@ -143,17 +137,17 @@ bool ExchangeKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std
|
|||
}
|
||||
|
||||
if (ReachThresholdForExchangeKeys(fbb, iter_num)) {
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
|
||||
flatbuffers::Verifier verifier(reinterpret_cast<uint8_t *>(req_data), inputs[0]->size);
|
||||
flatbuffers::Verifier verifier(req_data, len);
|
||||
if (!verifier.VerifyBuffer<schema::RequestExchangeKeys>()) {
|
||||
std::string reason = "The schema of RequestExchangeKeys is invalid.";
|
||||
cipher_key_->BuildExchangeKeysRsp(fbb, schema::ResponseCode_RequestError, reason,
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||
MS_LOG(ERROR) << reason;
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
const schema::RequestExchangeKeys *exchange_keys_req = flatbuffers::GetRoot<schema::RequestExchangeKeys>(req_data);
|
||||
|
@ -162,7 +156,7 @@ bool ExchangeKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std
|
|||
cipher_key_->BuildExchangeKeysRsp(fbb, schema::ResponseCode_RequestError, reason,
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||
MS_LOG(ERROR) << reason;
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -174,7 +168,7 @@ bool ExchangeKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std
|
|||
cipher_key_->BuildExchangeKeysRsp(fbb, schema::ResponseCode_RequestError, reason,
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||
MS_LOG(ERROR) << reason;
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -183,7 +177,7 @@ bool ExchangeKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std
|
|||
cipher_key_->BuildExchangeKeysRsp(fbb, schema::ResponseCode_OutOfTime, reason,
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||
MS_LOG(ERROR) << reason;
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -196,21 +190,21 @@ bool ExchangeKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std
|
|||
<< ". client request iteration is " << iter_client;
|
||||
cipher_key_->BuildExchangeKeysRsp(fbb, schema::ResponseCode_OutOfTime, "iter num is error.",
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
response = cipher_key_->ExchangeKeys(iter_num, std::to_string(CURRENT_TIME_MILLI.count()), exchange_keys_req, fbb);
|
||||
if (!response) {
|
||||
MS_LOG(ERROR) << "update exchange keys is failed.";
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
if (!CountForExchangeKeys(fbb, exchange_keys_req, iter_num)) {
|
||||
MS_LOG(ERROR) << "count for exchange keys failed.";
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
@ -39,8 +39,7 @@ class ExchangeKeysKernel : public RoundKernel {
|
|||
ExchangeKeysKernel() = default;
|
||||
~ExchangeKeysKernel() override = default;
|
||||
void InitKernel(size_t required_cnt) override;
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
bool Launch(const uint8_t *req_data, size_t len, const std::shared_ptr<ps::core::MessageHandler> &message) override;
|
||||
bool Reset() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -90,19 +90,12 @@ sigVerifyResult GetKeysKernel::VerifySignature(const schema::GetExchangeKeys *ge
|
|||
return sigVerifyResult::PASSED;
|
||||
}
|
||||
|
||||
bool GetKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
bool GetKeysKernel::Launch(const uint8_t *req_data, size_t len,
|
||||
const std::shared_ptr<ps::core::MessageHandler> &message) {
|
||||
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
|
||||
MS_LOG(INFO) << "Launching GetKeys kernel, ITERATION NUMBER IS : " << iter_num;
|
||||
bool response = false;
|
||||
if (inputs.size() != 1 || outputs.size() != 1) {
|
||||
std::string reason = "inputs or outputs size is invalid.";
|
||||
MS_LOG(ERROR) << reason;
|
||||
return false;
|
||||
}
|
||||
|
||||
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
|
||||
void *req_data = inputs[0]->addr;
|
||||
if (fbb == nullptr || req_data == nullptr) {
|
||||
std::string reason = "FBBuilder builder or req_data is nullptr.";
|
||||
MS_LOG(ERROR) << reason;
|
||||
|
@ -111,13 +104,13 @@ bool GetKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vec
|
|||
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
|
||||
MS_LOG(WARNING) << "Current amount for GetKeysKernel is enough.";
|
||||
}
|
||||
flatbuffers::Verifier verifier(reinterpret_cast<uint8_t *>(req_data), inputs[0]->size);
|
||||
flatbuffers::Verifier verifier(req_data, len);
|
||||
if (!verifier.VerifyBuffer<schema::GetExchangeKeys>()) {
|
||||
std::string reason = "The schema of GetExchangeKeys is invalid.";
|
||||
cipher_key_->BuildGetKeysRsp(fbb, schema::ResponseCode_RequestError, iter_num,
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), false);
|
||||
MS_LOG(ERROR) << reason;
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
const schema::GetExchangeKeys *get_exchange_keys_req = flatbuffers::GetRoot<schema::GetExchangeKeys>(req_data);
|
||||
|
@ -126,7 +119,7 @@ bool GetKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vec
|
|||
cipher_key_->BuildGetKeysRsp(fbb, schema::ResponseCode_RequestError, iter_num,
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), false);
|
||||
MS_LOG(ERROR) << reason;
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -138,7 +131,7 @@ bool GetKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vec
|
|||
cipher_key_->BuildGetKeysRsp(fbb, schema::ResponseCode_RequestError, iter_num,
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), false);
|
||||
MS_LOG(ERROR) << reason;
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -147,7 +140,7 @@ bool GetKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vec
|
|||
cipher_key_->BuildGetKeysRsp(fbb, schema::ResponseCode_OutOfTime, iter_num,
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), false);
|
||||
MS_LOG(ERROR) << reason;
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
MS_LOG(INFO) << "verify signature passed!";
|
||||
|
@ -159,20 +152,20 @@ bool GetKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vec
|
|||
<< ". client request iteration is " << iter_client;
|
||||
cipher_key_->BuildGetKeysRsp(fbb, schema::ResponseCode_OutOfTime, iter_num,
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), false);
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
response = cipher_key_->GetKeys(iter_num, std::to_string(CURRENT_TIME_MILLI.count()), get_exchange_keys_req, fbb);
|
||||
if (!response) {
|
||||
MS_LOG(WARNING) << "get public keys not ready.";
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
if (!CountForGetKeys(fbb, get_exchange_keys_req, iter_num)) {
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
GenerateOutput(outputs, fbb->GetCurrentBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetCurrentBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
@ -38,8 +38,7 @@ class GetKeysKernel : public RoundKernel {
|
|||
GetKeysKernel() = default;
|
||||
~GetKeysKernel() override = default;
|
||||
void InitKernel(size_t required_cnt) override;
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
bool Launch(const uint8_t *req_data, size_t len, const std::shared_ptr<ps::core::MessageHandler> &message) override;
|
||||
bool Reset() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -78,30 +78,24 @@ sigVerifyResult GetListSignKernel::VerifySignature(const schema::RequestAllClien
|
|||
return sigVerifyResult::PASSED;
|
||||
}
|
||||
|
||||
bool GetListSignKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
bool GetListSignKernel::Launch(const uint8_t *req_data, size_t len,
|
||||
const std::shared_ptr<ps::core::MessageHandler> &message) {
|
||||
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
|
||||
MS_LOG(INFO) << "Launching GetListSign kernel, Iteration number is " << iter_num;
|
||||
if (inputs.size() != 1 || outputs.size() != 1) {
|
||||
std::string reason = "inputs or outputs size is invalid.";
|
||||
MS_LOG(ERROR) << reason;
|
||||
return false;
|
||||
}
|
||||
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
|
||||
void *req_data = inputs[0]->addr;
|
||||
if (fbb == nullptr || req_data == nullptr) {
|
||||
std::string reason = "FBBuilder builder or req_data is nullptr.";
|
||||
MS_LOG(ERROR) << reason;
|
||||
return false;
|
||||
}
|
||||
std::map<std::string, std::vector<unsigned char>> list_signs;
|
||||
flatbuffers::Verifier verifier(reinterpret_cast<uint8_t *>(req_data), inputs[0]->size);
|
||||
flatbuffers::Verifier verifier(req_data, len);
|
||||
if (!verifier.VerifyBuffer<schema::RequestAllClientListSign>()) {
|
||||
std::string reason = "The schema of RequestAllClientListSign is invalid.";
|
||||
BuildGetListSignKernelRsp(fbb, schema::ResponseCode_RequestError, reason,
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num, list_signs);
|
||||
MS_LOG(ERROR) << reason;
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
const schema::RequestAllClientListSign *get_list_sign_req =
|
||||
|
@ -111,7 +105,7 @@ bool GetListSignKernel::Launch(const std::vector<AddressPtr> &inputs, const std:
|
|||
BuildGetListSignKernelRsp(fbb, schema::ResponseCode_RequestError, reason,
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num, list_signs);
|
||||
MS_LOG(ERROR) << reason;
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -123,7 +117,7 @@ bool GetListSignKernel::Launch(const std::vector<AddressPtr> &inputs, const std:
|
|||
BuildGetListSignKernelRsp(fbb, schema::ResponseCode_RequestError, reason,
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num, list_signs);
|
||||
MS_LOG(ERROR) << reason;
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -132,7 +126,7 @@ bool GetListSignKernel::Launch(const std::vector<AddressPtr> &inputs, const std:
|
|||
BuildGetListSignKernelRsp(fbb, schema::ResponseCode_OutOfTime, reason, std::to_string(CURRENT_TIME_MILLI.count()),
|
||||
iter_num, list_signs);
|
||||
MS_LOG(ERROR) << reason;
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -147,7 +141,7 @@ bool GetListSignKernel::Launch(const std::vector<AddressPtr> &inputs, const std:
|
|||
<< ". client request iteration is " << iter_client;
|
||||
BuildGetListSignKernelRsp(fbb, schema::ResponseCode_OutOfTime, "iter num is error.",
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num, list_signs);
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
std::string fl_id = get_list_sign_req->fl_id()->str();
|
||||
|
@ -156,7 +150,7 @@ bool GetListSignKernel::Launch(const std::vector<AddressPtr> &inputs, const std:
|
|||
}
|
||||
if (!GetListSign(iter_num, std::to_string(CURRENT_TIME_MILLI.count()), get_list_sign_req, fbb)) {
|
||||
MS_LOG(WARNING) << "get list signs not ready.";
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
std::string count_reason = "";
|
||||
|
@ -167,7 +161,7 @@ bool GetListSignKernel::Launch(const std::vector<AddressPtr> &inputs, const std:
|
|||
MS_LOG(ERROR) << reason;
|
||||
return true;
|
||||
}
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
@ -38,8 +38,7 @@ class GetListSignKernel : public RoundKernel {
|
|||
GetListSignKernel() = default;
|
||||
~GetListSignKernel() override = default;
|
||||
void InitKernel(size_t required_cnt) override;
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
bool Launch(const uint8_t *req_data, size_t len, const std::shared_ptr<ps::core::MessageHandler> &message) override;
|
||||
bool Reset() override;
|
||||
void BuildGetListSignKernelRsp(const std::shared_ptr<server::FBBuilder> &fbb, const schema::ResponseCode retcode,
|
||||
const string &reason, const string &next_req_time, const size_t iteration,
|
||||
|
|
|
@ -39,31 +39,23 @@ void GetModelKernel::InitKernel(size_t) {
|
|||
}
|
||||
}
|
||||
|
||||
bool GetModelKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
if (inputs.size() != 1 || outputs.size() != 1) {
|
||||
std::string reason = "inputs or outputs size is invalid.";
|
||||
MS_LOG(ERROR) << reason;
|
||||
GenerateOutput(outputs, reason.c_str(), reason.size());
|
||||
return true;
|
||||
}
|
||||
|
||||
void *req_data = inputs[0]->addr;
|
||||
bool GetModelKernel::Launch(const uint8_t *req_data, size_t len,
|
||||
const std::shared_ptr<ps::core::MessageHandler> &message) {
|
||||
std::shared_ptr<FBBuilder> fbb = std::make_shared<FBBuilder>();
|
||||
if (fbb == nullptr || req_data == nullptr) {
|
||||
std::string reason = "FBBuilder builder or req_data is nullptr.";
|
||||
MS_LOG(ERROR) << reason;
|
||||
GenerateOutput(outputs, reason.c_str(), reason.size());
|
||||
GenerateOutput(message, reason.c_str(), reason.size());
|
||||
return true;
|
||||
}
|
||||
|
||||
flatbuffers::Verifier verifier(reinterpret_cast<uint8_t *>(req_data), inputs[0]->size);
|
||||
flatbuffers::Verifier verifier(req_data, len);
|
||||
if (!verifier.VerifyBuffer<schema::RequestGetModel>()) {
|
||||
std::string reason = "The schema of RequestGetModel is invalid.";
|
||||
BuildGetModelRsp(fbb, schema::ResponseCode_RequestError, reason, LocalMetaStore::GetInstance().curr_iter_num(), {},
|
||||
"");
|
||||
MS_LOG(ERROR) << reason;
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -76,11 +68,11 @@ bool GetModelKernel::Launch(const std::vector<AddressPtr> &inputs, const std::ve
|
|||
if (get_model_req == nullptr) {
|
||||
std::string reason = "Building flatbuffers schema failed for RequestGetModel.";
|
||||
MS_LOG(ERROR) << reason;
|
||||
GenerateOutput(outputs, reason.c_str(), reason.size());
|
||||
GenerateOutput(message, reason.c_str(), reason.size());
|
||||
return true;
|
||||
}
|
||||
GetModel(get_model_req, fbb);
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
@ -37,8 +37,7 @@ class GetModelKernel : public RoundKernel {
|
|||
~GetModelKernel() override = default;
|
||||
|
||||
void InitKernel(size_t) override;
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs);
|
||||
bool Launch(const uint8_t *req_data, size_t len, const std::shared_ptr<ps::core::MessageHandler> &message) override;
|
||||
bool Reset() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -91,31 +91,23 @@ sigVerifyResult GetSecretsKernel::VerifySignature(const schema::GetShareSecrets
|
|||
return sigVerifyResult::PASSED;
|
||||
}
|
||||
|
||||
bool GetSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
bool GetSecretsKernel::Launch(const uint8_t *req_data, size_t len,
|
||||
const std::shared_ptr<ps::core::MessageHandler> &message) {
|
||||
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
|
||||
std::string next_timestamp = std::to_string(CURRENT_TIME_MILLI.count());
|
||||
MS_LOG(INFO) << "Launching get secrets kernel, ITERATION NUMBER IS : " << iter_num;
|
||||
|
||||
if (inputs.size() != 1 || outputs.size() != 1) {
|
||||
std::string reason = "inputs or outputs size is invalid.";
|
||||
MS_LOG(ERROR) << reason;
|
||||
return false;
|
||||
}
|
||||
|
||||
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
|
||||
void *req_data = inputs[0]->addr;
|
||||
if (fbb == nullptr || req_data == nullptr) {
|
||||
std::string reason = "FBBuilder builder or req_data is nullptr.";
|
||||
MS_LOG(ERROR) << reason;
|
||||
return false;
|
||||
}
|
||||
flatbuffers::Verifier verifier(reinterpret_cast<uint8_t *>(req_data), inputs[0]->size);
|
||||
flatbuffers::Verifier verifier(req_data, len);
|
||||
if (!verifier.VerifyBuffer<schema::GetShareSecrets>()) {
|
||||
std::string reason = "The schema of GetShareSecrets is invalid.";
|
||||
cipher_share_->BuildGetSecretsRsp(fbb, schema::ResponseCode_RequestError, iter_num, next_timestamp, nullptr);
|
||||
MS_LOG(ERROR) << reason;
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
const schema::GetShareSecrets *get_secrets_req = flatbuffers::GetRoot<schema::GetShareSecrets>(req_data);
|
||||
|
@ -123,7 +115,7 @@ bool GetSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
|
|||
std::string reason = "Building flatbuffers schema failed for GetExchangeKeys.";
|
||||
cipher_share_->BuildGetSecretsRsp(fbb, schema::ResponseCode_RequestError, iter_num, next_timestamp, nullptr);
|
||||
MS_LOG(ERROR) << reason;
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -134,7 +126,7 @@ bool GetSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
|
|||
std::string reason = "verify signature failed.";
|
||||
cipher_share_->BuildGetSecretsRsp(fbb, schema::ResponseCode_RequestError, iter_num, next_timestamp, nullptr);
|
||||
MS_LOG(ERROR) << reason;
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -142,7 +134,7 @@ bool GetSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
|
|||
std::string reason = "verify signature timestamp failed.";
|
||||
cipher_share_->BuildGetSecretsRsp(fbb, schema::ResponseCode_OutOfTime, iter_num, next_timestamp, nullptr);
|
||||
MS_LOG(ERROR) << reason;
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -155,7 +147,7 @@ bool GetSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
|
|||
MS_LOG(ERROR) << "GetSecretsKernel iteration invalid. server now iteration is " << iter_num
|
||||
<< ". client request iteration is " << iter_client;
|
||||
cipher_share_->BuildGetSecretsRsp(fbb, schema::ResponseCode_OutOfTime, iter_num, next_timestamp, nullptr);
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -166,14 +158,14 @@ bool GetSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
|
|||
bool response = cipher_share_->GetSecrets(get_secrets_req, fbb, next_timestamp);
|
||||
if (!response) {
|
||||
MS_LOG(WARNING) << "get secret shares not ready.";
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
if (!CountForGetSecrets(fbb, get_secrets_req, iter_num)) {
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
@ -37,8 +37,7 @@ class GetSecretsKernel : public RoundKernel {
|
|||
GetSecretsKernel() = default;
|
||||
~GetSecretsKernel() override = default;
|
||||
void InitKernel(size_t required_cnt) override;
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
bool Launch(const uint8_t *req_data, size_t len, const std::shared_ptr<ps::core::MessageHandler> &message) override;
|
||||
bool Reset() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -34,17 +34,9 @@ void PullWeightKernel::InitKernel(size_t) {
|
|||
}
|
||||
}
|
||||
|
||||
bool PullWeightKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
bool PullWeightKernel::Launch(const uint8_t *req_data, size_t len,
|
||||
const std::shared_ptr<ps::core::MessageHandler> &message) {
|
||||
MS_LOG(DEBUG) << "Launching PullWeightKernel kernel.";
|
||||
if (inputs.size() != 1 || outputs.size() != 1) {
|
||||
std::string reason = "inputs or outputs size is invalid.";
|
||||
MS_LOG(ERROR) << reason;
|
||||
GenerateOutput(outputs, reason.c_str(), reason.size());
|
||||
return true;
|
||||
}
|
||||
|
||||
void *req_data = inputs[0]->addr;
|
||||
std::shared_ptr<FBBuilder> fbb = std::make_shared<FBBuilder>();
|
||||
if (fbb == nullptr || req_data == nullptr) {
|
||||
MS_LOG(ERROR) << "FBBuilder builder or req_data is nullptr.";
|
||||
|
@ -56,12 +48,12 @@ bool PullWeightKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
|
|||
std::string reason = "Building flatbuffers schema failed for RequestPullWeight";
|
||||
BuildPullWeightRsp(fbb, schema::ResponseCode_RequestError, reason, LocalMetaStore::GetInstance().curr_iter_num(),
|
||||
{});
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return false;
|
||||
}
|
||||
|
||||
PullWeight(fbb, pull_weight_req);
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
@ -37,8 +37,7 @@ class PullWeightKernel : public RoundKernel {
|
|||
~PullWeightKernel() override = default;
|
||||
|
||||
void InitKernel(size_t required_cnt) override;
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs);
|
||||
bool Launch(const uint8_t *req_data, size_t len, const std::shared_ptr<ps::core::MessageHandler> &message) override;
|
||||
bool Reset() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -33,29 +33,23 @@ void PushListSignKernel::InitKernel(size_t) {
|
|||
cipher_init_ = &armour::CipherInit::GetInstance();
|
||||
}
|
||||
|
||||
bool PushListSignKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
bool PushListSignKernel::Launch(const uint8_t *req_data, size_t len,
|
||||
const std::shared_ptr<ps::core::MessageHandler> &message) {
|
||||
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
|
||||
MS_LOG(INFO) << "Launching PushListSignKernel, Iteration number is " << iter_num;
|
||||
if (inputs.size() != 1 || outputs.size() != 1) {
|
||||
std::string reason = "inputs or outputs size is invalid.";
|
||||
MS_LOG(ERROR) << reason;
|
||||
return false;
|
||||
}
|
||||
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
|
||||
void *req_data = inputs[0]->addr;
|
||||
if (fbb == nullptr || req_data == nullptr) {
|
||||
std::string reason = "FBBuilder builder or req_data is nullptr.";
|
||||
MS_LOG(ERROR) << reason;
|
||||
return false;
|
||||
}
|
||||
flatbuffers::Verifier verifier(reinterpret_cast<uint8_t *>(req_data), inputs[0]->size);
|
||||
flatbuffers::Verifier verifier(req_data, len);
|
||||
if (!verifier.VerifyBuffer<schema::SendClientListSign>()) {
|
||||
std::string reason = "The schema of PushClientListSign is invalid.";
|
||||
BuildPushListSignKernelRsp(fbb, schema::ResponseCode_RequestError, reason,
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||
MS_LOG(ERROR) << reason;
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
const schema::SendClientListSign *client_list_sign_req = flatbuffers::GetRoot<schema::SendClientListSign>(req_data);
|
||||
|
@ -64,7 +58,7 @@ bool PushListSignKernel::Launch(const std::vector<AddressPtr> &inputs, const std
|
|||
BuildPushListSignKernelRsp(fbb, schema::ResponseCode_RequestError, reason,
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||
MS_LOG(ERROR) << reason;
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
// verify signature
|
||||
|
@ -75,7 +69,7 @@ bool PushListSignKernel::Launch(const std::vector<AddressPtr> &inputs, const std
|
|||
BuildPushListSignKernelRsp(fbb, schema::ResponseCode_RequestError, reason,
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||
MS_LOG(ERROR) << reason;
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
if (verify_result == sigVerifyResult::TIMEOUT) {
|
||||
|
@ -83,17 +77,17 @@ bool PushListSignKernel::Launch(const std::vector<AddressPtr> &inputs, const std
|
|||
BuildPushListSignKernelRsp(fbb, schema::ResponseCode_OutOfTime, reason,
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||
MS_LOG(ERROR) << reason;
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
MS_LOG(INFO) << "verify signature passed!";
|
||||
}
|
||||
return LaunchForPushListSign(client_list_sign_req, iter_num, fbb, outputs);
|
||||
return LaunchForPushListSign(client_list_sign_req, iter_num, fbb, message);
|
||||
}
|
||||
|
||||
bool PushListSignKernel::LaunchForPushListSign(const schema::SendClientListSign *client_list_sign_req,
|
||||
const size_t &iter_num, const std::shared_ptr<server::FBBuilder> &fbb,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
const std::shared_ptr<ps::core::MessageHandler> &message) {
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(client_list_sign_req, false);
|
||||
size_t iter_client = IntToSize(client_list_sign_req->iteration());
|
||||
if (iter_num != iter_client) {
|
||||
|
@ -102,7 +96,7 @@ bool PushListSignKernel::LaunchForPushListSign(const schema::SendClientListSign
|
|||
MS_LOG(WARNING) << "server now iteration is " << iter_num << ". client request iteration is " << iter_client;
|
||||
BuildPushListSignKernelRsp(fbb, schema::ResponseCode_OutOfTime, reason, std::to_string(CURRENT_TIME_MILLI.count()),
|
||||
iter_num);
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
std::vector<string> update_model_clients;
|
||||
|
@ -125,13 +119,13 @@ bool PushListSignKernel::LaunchForPushListSign(const schema::SendClientListSign
|
|||
"Current amount for PushListSignKernel is enough.",
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||
}
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
if (!PushListSign(iter_num, std::to_string(CURRENT_TIME_MILLI.count()), client_list_sign_req, fbb,
|
||||
update_model_clients)) {
|
||||
MS_LOG(ERROR) << "push client list sign failed.";
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
std::string count_reason = "";
|
||||
|
@ -140,9 +134,10 @@ bool PushListSignKernel::LaunchForPushListSign(const schema::SendClientListSign
|
|||
BuildPushListSignKernelRsp(fbb, schema::ResponseCode_OutOfTime, reason, std::to_string(CURRENT_TIME_MILLI.count()),
|
||||
iter_num);
|
||||
MS_LOG(ERROR) << reason;
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
@ -38,10 +38,10 @@ class PushListSignKernel : public RoundKernel {
|
|||
PushListSignKernel() = default;
|
||||
~PushListSignKernel() override = default;
|
||||
void InitKernel(size_t required_cnt) override;
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
bool Launch(const uint8_t *req_data, size_t len, const std::shared_ptr<ps::core::MessageHandler> &message) override;
|
||||
bool LaunchForPushListSign(const schema::SendClientListSign *client_list_sign_req, const size_t &iter_num,
|
||||
const std::shared_ptr<server::FBBuilder> &fbb, const std::vector<AddressPtr> &outputs);
|
||||
const std::shared_ptr<server::FBBuilder> &fbb,
|
||||
const std::shared_ptr<ps::core::MessageHandler> &message);
|
||||
bool Reset() override;
|
||||
void BuildPushListSignKernelRsp(const std::shared_ptr<server::FBBuilder> &fbb, const schema::ResponseCode retcode,
|
||||
const string &reason, const string &next_req_time, const size_t iteration);
|
||||
|
|
|
@ -24,24 +24,23 @@ namespace server {
|
|||
namespace kernel {
|
||||
void PushMetricsKernel::InitKernel(size_t) { local_rank_ = DistributedCountService::GetInstance().local_rank(); }
|
||||
|
||||
bool PushMetricsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
bool PushMetricsKernel::Launch(const uint8_t *req_data, size_t len,
|
||||
const std::shared_ptr<ps::core::MessageHandler> &message) {
|
||||
MS_LOG(INFO) << "Launching PushMetricsKernel kernel.";
|
||||
void *req_data = inputs[0]->addr;
|
||||
std::shared_ptr<FBBuilder> fbb = std::make_shared<FBBuilder>();
|
||||
if (fbb == nullptr || req_data == nullptr) {
|
||||
std::string reason = "FBBuilder builder or req_data is nullptr.";
|
||||
MS_LOG(ERROR) << reason;
|
||||
GenerateOutput(outputs, reason.c_str(), reason.size());
|
||||
GenerateOutput(message, reason.c_str(), reason.size());
|
||||
return true;
|
||||
}
|
||||
|
||||
flatbuffers::Verifier verifier(reinterpret_cast<uint8_t *>(req_data), inputs[0]->size);
|
||||
flatbuffers::Verifier verifier(req_data, len);
|
||||
if (!verifier.VerifyBuffer<schema::RequestPushMetrics>()) {
|
||||
std::string reason = "The schema of RequestPushMetrics is invalid.";
|
||||
BuildPushMetricsRsp(fbb, schema::ResponseCode_RequestError);
|
||||
MS_LOG(ERROR) << reason;
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -50,12 +49,12 @@ bool PushMetricsKernel::Launch(const std::vector<AddressPtr> &inputs, const std:
|
|||
std::string reason = "Building flatbuffers schema failed for RequestPushMetrics";
|
||||
BuildPushMetricsRsp(fbb, schema::ResponseCode_RequestError);
|
||||
MS_LOG(ERROR) << reason;
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return false;
|
||||
}
|
||||
|
||||
ResultCode result_code = PushMetrics(fbb, push_metrics_req);
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return ConvertResultCode(result_code);
|
||||
}
|
||||
|
||||
|
|
|
@ -36,8 +36,7 @@ class PushMetricsKernel : public RoundKernel {
|
|||
~PushMetricsKernel() override = default;
|
||||
|
||||
void InitKernel(size_t threshold_count);
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs);
|
||||
bool Launch(const uint8_t *req_data, size_t len, const std::shared_ptr<ps::core::MessageHandler> &message) override;
|
||||
bool Reset() override;
|
||||
void OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) override;
|
||||
|
||||
|
|
|
@ -30,31 +30,23 @@ void PushWeightKernel::InitKernel(size_t) {
|
|||
local_rank_ = DistributedCountService::GetInstance().local_rank();
|
||||
}
|
||||
|
||||
bool PushWeightKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
bool PushWeightKernel::Launch(const uint8_t *req_data, size_t len,
|
||||
const std::shared_ptr<ps::core::MessageHandler> &message) {
|
||||
MS_LOG(INFO) << "Launching PushWeightKernel kernel.";
|
||||
if (inputs.size() != 1 || outputs.size() != 1) {
|
||||
std::string reason = "inputs or outputs size is invalid.";
|
||||
MS_LOG(ERROR) << reason;
|
||||
GenerateOutput(outputs, reason.c_str(), reason.size());
|
||||
return true;
|
||||
}
|
||||
|
||||
void *req_data = inputs[0]->addr;
|
||||
std::shared_ptr<FBBuilder> fbb = std::make_shared<FBBuilder>();
|
||||
if (fbb == nullptr || req_data == nullptr) {
|
||||
std::string reason = "FBBuilder builder or req_data is nullptr.";
|
||||
MS_LOG(ERROR) << reason;
|
||||
GenerateOutput(outputs, reason.c_str(), reason.size());
|
||||
GenerateOutput(message, reason.c_str(), reason.size());
|
||||
return true;
|
||||
}
|
||||
|
||||
flatbuffers::Verifier verifier(reinterpret_cast<uint8_t *>(req_data), inputs[0]->size);
|
||||
flatbuffers::Verifier verifier(req_data, len);
|
||||
if (!verifier.VerifyBuffer<schema::RequestPushWeight>()) {
|
||||
std::string reason = "The schema of RequestPushWeight is invalid.";
|
||||
BuildPushWeightRsp(fbb, schema::ResponseCode_RequestError, reason, LocalMetaStore::GetInstance().curr_iter_num());
|
||||
MS_LOG(ERROR) << reason;
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -62,12 +54,12 @@ bool PushWeightKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
|
|||
if (push_weight_req == nullptr) {
|
||||
std::string reason = "Building flatbuffers schema failed for RequestPushWeight";
|
||||
BuildPushWeightRsp(fbb, schema::ResponseCode_RequestError, reason, LocalMetaStore::GetInstance().curr_iter_num());
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return false;
|
||||
}
|
||||
|
||||
ResultCode result_code = PushWeight(fbb, push_weight_req);
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return ConvertResultCode(result_code);
|
||||
}
|
||||
|
||||
|
|
|
@ -36,8 +36,7 @@ class PushWeightKernel : public RoundKernel {
|
|||
~PushWeightKernel() override = default;
|
||||
|
||||
void InitKernel(size_t threshold_count) override;
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs);
|
||||
bool Launch(const uint8_t *req_data, size_t len, const std::shared_ptr<ps::core::MessageHandler> &message) override;
|
||||
bool Reset() override;
|
||||
void OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) override;
|
||||
|
||||
|
|
|
@ -90,20 +90,13 @@ sigVerifyResult ReconstructSecretsKernel::VerifySignature(const schema::SendReco
|
|||
return sigVerifyResult::PASSED;
|
||||
}
|
||||
|
||||
bool ReconstructSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
bool ReconstructSecretsKernel::Launch(const uint8_t *req_data, size_t len,
|
||||
const std::shared_ptr<ps::core::MessageHandler> &message) {
|
||||
bool response = false;
|
||||
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
|
||||
MS_LOG(INFO) << "Launching ReconstructSecrets Kernel, Iteration number is " << iter_num;
|
||||
|
||||
if (inputs.size() != 1 || outputs.size() != 1) {
|
||||
MS_LOG(ERROR) << "ReconstructSecretsKernel needs 1 input, but got " << inputs.size();
|
||||
return false;
|
||||
}
|
||||
|
||||
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
|
||||
void *req_data = inputs[0]->addr;
|
||||
|
||||
if (fbb == nullptr || req_data == nullptr) {
|
||||
std::string reason = "FBBuilder builder or req_data is nullptr.";
|
||||
MS_LOG(ERROR) << reason;
|
||||
|
@ -119,13 +112,13 @@ bool ReconstructSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, con
|
|||
for (size_t i = 0; i < IntToSize(update_model_clients_pb.fl_id_size()); ++i) {
|
||||
update_model_clients.push_back(update_model_clients_pb.fl_id(SizeToInt(i)));
|
||||
}
|
||||
flatbuffers::Verifier verifier(reinterpret_cast<uint8_t *>(req_data), inputs[0]->size);
|
||||
flatbuffers::Verifier verifier(req_data, len);
|
||||
if (!verifier.VerifyBuffer<schema::SendReconstructSecret>()) {
|
||||
std::string reason = "The schema of SendReconstructSecret is invalid.";
|
||||
cipher_reconstruct_.BuildReconstructSecretsRsp(fbb, schema::ResponseCode_RequestError, reason, SizeToInt(iter_num),
|
||||
std::to_string(CURRENT_TIME_MILLI.count()));
|
||||
MS_LOG(ERROR) << reason;
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
const schema::SendReconstructSecret *reconstruct_secret_req =
|
||||
|
@ -135,7 +128,7 @@ bool ReconstructSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, con
|
|||
cipher_reconstruct_.BuildReconstructSecretsRsp(fbb, schema::ResponseCode_RequestError, reason, SizeToInt(iter_num),
|
||||
std::to_string(CURRENT_TIME_MILLI.count()));
|
||||
MS_LOG(ERROR) << reason;
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
// verify signature
|
||||
|
@ -146,7 +139,7 @@ bool ReconstructSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, con
|
|||
cipher_reconstruct_.BuildReconstructSecretsRsp(fbb, schema::ResponseCode_RequestError, reason,
|
||||
SizeToInt(iter_num), std::to_string(CURRENT_TIME_MILLI.count()));
|
||||
MS_LOG(ERROR) << reason;
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -155,7 +148,7 @@ bool ReconstructSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, con
|
|||
cipher_reconstruct_.BuildReconstructSecretsRsp(fbb, schema::ResponseCode_OutOfTime, reason, SizeToInt(iter_num),
|
||||
std::to_string(CURRENT_TIME_MILLI.count()));
|
||||
MS_LOG(ERROR) << reason;
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
MS_LOG(INFO) << "verify signature passed!";
|
||||
|
@ -174,7 +167,7 @@ bool ReconstructSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, con
|
|||
"Current amount for ReconstructSecretsKernel is enough.",
|
||||
SizeToInt(iter_num), std::to_string(CURRENT_TIME_MILLI.count()));
|
||||
}
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -186,7 +179,7 @@ bool ReconstructSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, con
|
|||
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
|
||||
MS_LOG(INFO) << "Current amount for ReconstructSecretsKernel is enough.";
|
||||
}
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
|
||||
MS_LOG(INFO) << "reconstruct_secrets_kernel success.";
|
||||
if (!response) {
|
||||
|
|
|
@ -39,8 +39,7 @@ class ReconstructSecretsKernel : public RoundKernel {
|
|||
~ReconstructSecretsKernel() override = default;
|
||||
|
||||
void InitKernel(size_t required_cnt) override;
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
bool Launch(const uint8_t *req_data, size_t len, const std::shared_ptr<ps::core::MessageHandler> &message) override;
|
||||
bool Reset() override;
|
||||
void OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) override;
|
||||
|
||||
|
|
|
@ -27,39 +27,9 @@ namespace mindspore {
|
|||
namespace fl {
|
||||
namespace server {
|
||||
namespace kernel {
|
||||
RoundKernel::RoundKernel() : name_(""), current_count_(0), required_count_(0), error_reason_(""), running_(true) {
|
||||
release_thread_ = std::thread([&]() {
|
||||
while (running_.load()) {
|
||||
std::unique_lock<std::mutex> release_lock(release_mtx_);
|
||||
// Detect whether there's any data needs to be released every 100 milliseconds.
|
||||
if (heap_data_to_release_.empty()) {
|
||||
release_lock.unlock();
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(kReleaseDuration));
|
||||
continue;
|
||||
}
|
||||
RoundKernel::RoundKernel() : name_(""), current_count_(0) {}
|
||||
|
||||
AddressPtr addr_ptr = heap_data_to_release_.front();
|
||||
heap_data_to_release_.pop();
|
||||
release_lock.unlock();
|
||||
|
||||
std::unique_lock<std::mutex> heap_data_lock(heap_data_mtx_);
|
||||
if (heap_data_.count(addr_ptr) == 0) {
|
||||
MS_LOG(ERROR) << "The data is not stored.";
|
||||
continue;
|
||||
}
|
||||
// Manually release unique_ptr data.
|
||||
heap_data_[addr_ptr].reset(nullptr);
|
||||
(void)heap_data_.erase(heap_data_.find(addr_ptr));
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
RoundKernel::~RoundKernel() {
|
||||
running_ = false;
|
||||
if (release_thread_.joinable()) {
|
||||
release_thread_.join();
|
||||
}
|
||||
}
|
||||
RoundKernel::~RoundKernel() {}
|
||||
|
||||
void RoundKernel::OnFirstCountEvent(const std::shared_ptr<ps::core::MessageHandler> &) { return; }
|
||||
|
||||
|
@ -79,16 +49,6 @@ void RoundKernel::FinishIteration() const {
|
|||
return;
|
||||
}
|
||||
|
||||
void RoundKernel::Release(const AddressPtr &addr_ptr) {
|
||||
if (addr_ptr == nullptr) {
|
||||
MS_LOG(WARNING) << "Data to be released is empty.";
|
||||
return;
|
||||
}
|
||||
std::unique_lock<std::mutex> lock(release_mtx_);
|
||||
heap_data_to_release_.push(addr_ptr);
|
||||
return;
|
||||
}
|
||||
|
||||
void RoundKernel::set_name(const std::string &name) { name_ = name; }
|
||||
|
||||
void RoundKernel::set_stop_timer_cb(const StopTimerCb &timer_stopper) { stop_timer_cb_ = timer_stopper; }
|
||||
|
@ -97,36 +57,26 @@ void RoundKernel::set_finish_iteration_cb(const FinishIterCb &finish_iteration_c
|
|||
finish_iteration_cb_ = finish_iteration_cb;
|
||||
}
|
||||
|
||||
void RoundKernel::GenerateOutput(const std::vector<AddressPtr> &outputs, const void *data, size_t len) {
|
||||
if (data == nullptr) {
|
||||
MS_LOG(WARNING) << "The data is nullptr.";
|
||||
void RoundKernel::GenerateOutput(const std::shared_ptr<ps::core::MessageHandler> &message, const void *data,
|
||||
size_t len) {
|
||||
if (message == nullptr) {
|
||||
MS_LOG(WARNING) << "The message handler is nullptr.";
|
||||
return;
|
||||
}
|
||||
|
||||
if (outputs.empty()) {
|
||||
MS_LOG(WARNING) << "Generating output failed. Outputs size is empty.";
|
||||
if (data == nullptr || len == 0) {
|
||||
std::string reason = "The output of the round " + name_ + " is empty.";
|
||||
MS_LOG(WARNING) << reason;
|
||||
if (!message->SendResponse(reason.c_str(), reason.size())) {
|
||||
MS_LOG(WARNING) << "Sending response failed.";
|
||||
return;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
std::unique_ptr<unsigned char[]> output_data = std::make_unique<unsigned char[]>(len);
|
||||
if (output_data == nullptr) {
|
||||
MS_LOG(WARNING) << "Output data is nullptr.";
|
||||
return;
|
||||
}
|
||||
|
||||
size_t dst_size = len;
|
||||
int ret = memcpy_s(output_data.get(), dst_size, data, len);
|
||||
if (ret != 0) {
|
||||
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
|
||||
return;
|
||||
}
|
||||
outputs[0]->addr = output_data.get();
|
||||
outputs[0]->size = len;
|
||||
|
||||
std::unique_lock<std::mutex> lock(heap_data_mtx_);
|
||||
(void)heap_data_.insert(std::make_pair(outputs[0], std::move(output_data)));
|
||||
IncreaseTotalClientNum();
|
||||
return;
|
||||
if (!message->SendResponse(data, len)) {
|
||||
MS_LOG(WARNING) << "Sending response failed.";
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
void RoundKernel::IncreaseTotalClientNum() { total_client_num_ += 1; }
|
||||
|
|
|
@ -45,22 +45,18 @@ constexpr uint64_t kReleaseDuration = 100;
|
|||
|
||||
// For example, the main process of federated learning is:
|
||||
// startFLJob round->updateModel round->getModel round.
|
||||
class RoundKernel : virtual public CPUKernel {
|
||||
class RoundKernel {
|
||||
public:
|
||||
RoundKernel();
|
||||
virtual ~RoundKernel();
|
||||
|
||||
// RoundKernel doesn't use InitKernel method of base class CPUKernel to initialize. So implementation of this
|
||||
// inherited method is empty.
|
||||
void InitKernel(const CNodePtr &kernel_node) override {}
|
||||
|
||||
// Initialize RoundKernel with threshold_count which means that for every iteration, this round needs threshold_count
|
||||
// messages.
|
||||
virtual void InitKernel(size_t threshold_count) = 0;
|
||||
|
||||
// Launch the round kernel logic to handle the message passed by the communication module.
|
||||
virtual bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) = 0;
|
||||
virtual bool Launch(const uint8_t *req_data, size_t len,
|
||||
const std::shared_ptr<ps::core::MessageHandler> &message) = 0;
|
||||
|
||||
// Some rounds could be stateful in a iteration. Reset method resets the status of this round.
|
||||
virtual bool Reset() = 0;
|
||||
|
@ -78,10 +74,6 @@ class RoundKernel : virtual public CPUKernel {
|
|||
// be called.
|
||||
void FinishIteration() const;
|
||||
|
||||
// Release the response data allocated inside the round kernel.
|
||||
// Server framework must call this after the response data is sent back.
|
||||
void Release(const AddressPtr &addr_ptr);
|
||||
|
||||
// Set round kernel name, which could be used in round kernel's methods.
|
||||
void set_name(const std::string &name);
|
||||
|
||||
|
@ -112,36 +104,16 @@ class RoundKernel : virtual public CPUKernel {
|
|||
protected:
|
||||
// Generating response data of this round. The data is allocated on the heap to ensure it's not released before sent
|
||||
// back to worker.
|
||||
void GenerateOutput(const std::vector<AddressPtr> &outputs, const void *data, size_t len);
|
||||
|
||||
void GenerateOutput(const std::shared_ptr<ps::core::MessageHandler> &message, const void *data, size_t len);
|
||||
// Round kernel's name.
|
||||
std::string name_;
|
||||
|
||||
// The current received message count for this round in this iteration.
|
||||
size_t current_count_;
|
||||
|
||||
// The required received message count for this round in one iteration.
|
||||
size_t required_count_;
|
||||
|
||||
// The reason causes the error in this round kernel.
|
||||
std::string error_reason_;
|
||||
|
||||
StopTimerCb stop_timer_cb_;
|
||||
FinishIterCb finish_iteration_cb_;
|
||||
|
||||
// Members below are used for allocating and releasing response data on the heap.
|
||||
|
||||
// To ensure the performance, we use another thread to release data on the heap. So the operation on the data should
|
||||
// be threadsafe.
|
||||
std::atomic_bool running_;
|
||||
std::thread release_thread_;
|
||||
|
||||
// Data needs to be released and its mutex;
|
||||
std::mutex release_mtx_;
|
||||
std::queue<AddressPtr> heap_data_to_release_;
|
||||
std::mutex heap_data_mtx_;
|
||||
std::unordered_map<AddressPtr, std::unique_ptr<unsigned char[]>> heap_data_;
|
||||
|
||||
std::atomic<size_t> total_client_num_;
|
||||
std::atomic<size_t> accept_client_num_;
|
||||
|
||||
|
|
|
@ -89,19 +89,13 @@ sigVerifyResult ShareSecretsKernel::VerifySignature(const schema::RequestShareSe
|
|||
return sigVerifyResult::PASSED;
|
||||
}
|
||||
|
||||
bool ShareSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
bool ShareSecretsKernel::Launch(const uint8_t *req_data, size_t len,
|
||||
const std::shared_ptr<ps::core::MessageHandler> &message) {
|
||||
bool response = false;
|
||||
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
|
||||
MS_LOG(INFO) << "Launching ShareSecretsKernel, ITERATION NUMBER IS : " << iter_num;
|
||||
|
||||
if (inputs.size() != 1 || outputs.size() != 1) {
|
||||
std::string reason = "inputs or outputs size is invalid.";
|
||||
MS_LOG(ERROR) << reason;
|
||||
return false;
|
||||
}
|
||||
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
|
||||
void *req_data = inputs[0]->addr;
|
||||
if (fbb == nullptr || req_data == nullptr) {
|
||||
std::string reason = "FBBuilder builder or req_data is nullptr.";
|
||||
MS_LOG(ERROR) << reason;
|
||||
|
@ -112,16 +106,16 @@ bool ShareSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std
|
|||
cipher_share_->BuildShareSecretsRsp(fbb, schema::ResponseCode_OutOfTime,
|
||||
"Current amount for ShareSecretsKernel is enough.",
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), SizeToInt(iter_num));
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
flatbuffers::Verifier verifier(reinterpret_cast<uint8_t *>(req_data), inputs[0]->size);
|
||||
flatbuffers::Verifier verifier(req_data, len);
|
||||
if (!verifier.VerifyBuffer<schema::RequestShareSecrets>()) {
|
||||
std::string reason = "The schema of RequestShareSecrets is invalid.";
|
||||
cipher_share_->BuildShareSecretsRsp(fbb, schema::ResponseCode_RequestError, reason,
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), SizeToInt(iter_num));
|
||||
MS_LOG(ERROR) << reason;
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
const schema::RequestShareSecrets *share_secrets_req = flatbuffers::GetRoot<schema::RequestShareSecrets>(req_data);
|
||||
|
@ -130,7 +124,7 @@ bool ShareSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std
|
|||
cipher_share_->BuildShareSecretsRsp(fbb, schema::ResponseCode_RequestError, reason,
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), SizeToInt(iter_num));
|
||||
MS_LOG(ERROR) << reason;
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
// verify signature
|
||||
|
@ -141,7 +135,7 @@ bool ShareSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std
|
|||
cipher_share_->BuildShareSecretsRsp(fbb, schema::ResponseCode_RequestError, reason,
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), SizeToInt(iter_num));
|
||||
MS_LOG(ERROR) << reason;
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -150,7 +144,7 @@ bool ShareSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std
|
|||
cipher_share_->BuildShareSecretsRsp(fbb, schema::ResponseCode_OutOfTime, reason,
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), SizeToInt(iter_num));
|
||||
MS_LOG(ERROR) << reason;
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
MS_LOG(INFO) << "verify signature passed!";
|
||||
|
@ -162,21 +156,21 @@ bool ShareSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std
|
|||
<< ". client request iteration is " << iter_client;
|
||||
cipher_share_->BuildShareSecretsRsp(fbb, schema::ResponseCode_OutOfTime, "ShareSecretsKernel iteration invalid",
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), SizeToInt(iter_num));
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
response = cipher_share_->ShareSecrets(SizeToInt(iter_num), share_secrets_req, fbb,
|
||||
std::to_string(CURRENT_TIME_MILLI.count()));
|
||||
if (!response) {
|
||||
MS_LOG(ERROR) << "update secret shares is failed.";
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
if (!CountForShareSecrets(fbb, share_secrets_req, iter_num)) {
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
@ -38,8 +38,7 @@ class ShareSecretsKernel : public RoundKernel {
|
|||
ShareSecretsKernel() = default;
|
||||
~ShareSecretsKernel() override = default;
|
||||
void InitKernel(size_t required_cnt) override;
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
bool Launch(const uint8_t *req_data, size_t len, const std::shared_ptr<ps::core::MessageHandler> &message) override;
|
||||
bool Reset() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -51,36 +51,29 @@ void StartFLJobKernel::InitKernel(size_t) {
|
|||
return;
|
||||
}
|
||||
|
||||
bool StartFLJobKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
bool StartFLJobKernel::Launch(const uint8_t *req_data, size_t len,
|
||||
const std::shared_ptr<ps::core::MessageHandler> &message) {
|
||||
MS_LOG(DEBUG) << "Launching StartFLJobKernel kernel.";
|
||||
if (inputs.size() != 1 || outputs.size() != 1) {
|
||||
std::string reason = "inputs or outputs size is invalid.";
|
||||
MS_LOG(ERROR) << reason;
|
||||
GenerateOutput(outputs, reason.c_str(), reason.size());
|
||||
return true;
|
||||
}
|
||||
void *req_data = inputs[0]->addr;
|
||||
std::shared_ptr<FBBuilder> fbb = std::make_shared<FBBuilder>();
|
||||
if (fbb == nullptr || req_data == nullptr) {
|
||||
std::string reason = "FBBuilder builder or req_data is nullptr.";
|
||||
MS_LOG(WARNING) << reason;
|
||||
GenerateOutput(outputs, reason.c_str(), reason.size());
|
||||
GenerateOutput(message, reason.c_str(), reason.size());
|
||||
return true;
|
||||
}
|
||||
|
||||
flatbuffers::Verifier verifier(reinterpret_cast<uint8_t *>(req_data), inputs[0]->size);
|
||||
flatbuffers::Verifier verifier(req_data, len);
|
||||
if (!verifier.VerifyBuffer<schema::RequestFLJob>()) {
|
||||
std::string reason = "The schema of RequestFLJob is invalid.";
|
||||
BuildStartFLJobRsp(fbb, schema::ResponseCode_RequestError, reason, false, "");
|
||||
MS_LOG(WARNING) << reason;
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
|
||||
ResultCode result_code = ReachThresholdForStartFLJob(fbb);
|
||||
if (result_code != ResultCode::kSuccess) {
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return ConvertResultCode(result_code);
|
||||
}
|
||||
|
||||
|
@ -91,17 +84,17 @@ bool StartFLJobKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
|
|||
fbb, schema::ResponseCode_RequestError, reason, false,
|
||||
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
|
||||
MS_LOG(WARNING) << reason;
|
||||
GenerateOutput(outputs, reason.c_str(), reason.size());
|
||||
GenerateOutput(message, reason.c_str(), reason.size());
|
||||
return true;
|
||||
}
|
||||
|
||||
if (ps::PSContext::instance()->pki_verify()) {
|
||||
if (!JudgeFLJobCert(fbb, start_fl_job_req)) {
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
if (!StoreKeyAttestation(fbb, start_fl_job_req)) {
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
@ -109,7 +102,7 @@ bool StartFLJobKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
|
|||
DeviceMeta device_meta = CreateDeviceMetadata(start_fl_job_req);
|
||||
result_code = ReadyForStartFLJob(fbb, device_meta);
|
||||
if (result_code != ResultCode::kSuccess) {
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return ConvertResultCode(result_code);
|
||||
}
|
||||
PBMetadata metadata;
|
||||
|
@ -120,7 +113,7 @@ bool StartFLJobKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
|
|||
BuildStartFLJobRsp(
|
||||
fbb, schema::ResponseCode_OutOfTime, reason, false,
|
||||
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return update_reason == kNetworkError ? false : true;
|
||||
}
|
||||
|
||||
|
@ -128,11 +121,11 @@ bool StartFLJobKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
|
|||
// If calling ReportCount before ReadyForStartFLJob, the result will be inconsistent if the device is not selected.
|
||||
result_code = CountForStartFLJob(fbb, start_fl_job_req);
|
||||
if (result_code != ResultCode::kSuccess) {
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return ConvertResultCode(result_code);
|
||||
}
|
||||
IncreaseAcceptClientNum();
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
@ -36,8 +36,7 @@ class StartFLJobKernel : public RoundKernel {
|
|||
~StartFLJobKernel() override = default;
|
||||
|
||||
void InitKernel(size_t threshold_count) override;
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
bool Launch(const uint8_t *req_data, size_t len, const std::shared_ptr<ps::core::MessageHandler> &message) override;
|
||||
bool Reset() override;
|
||||
|
||||
void OnFirstCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) override;
|
||||
|
|
|
@ -45,37 +45,29 @@ void UpdateModelKernel::InitKernel(size_t threshold_count) {
|
|||
LocalMetaStore::GetInstance().put_value(kCtxFedAvgTotalDataSize, kInitialDataSizeSum);
|
||||
}
|
||||
|
||||
bool UpdateModelKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
bool UpdateModelKernel::Launch(const uint8_t *req_data, size_t len,
|
||||
const std::shared_ptr<ps::core::MessageHandler> &message) {
|
||||
MS_LOG(DEBUG) << "Launching UpdateModelKernel kernel.";
|
||||
if (inputs.size() != 1 || outputs.size() != 1) {
|
||||
std::string reason = "inputs or outputs size is invalid.";
|
||||
MS_LOG(WARNING) << reason;
|
||||
GenerateOutput(outputs, reason.c_str(), reason.size());
|
||||
return true;
|
||||
}
|
||||
|
||||
void *req_data = inputs[0]->addr;
|
||||
std::shared_ptr<FBBuilder> fbb = std::make_shared<FBBuilder>();
|
||||
if (fbb == nullptr || req_data == nullptr) {
|
||||
std::string reason = "FBBuilder builder or req_data is nullptr.";
|
||||
MS_LOG(WARNING) << reason;
|
||||
GenerateOutput(outputs, reason.c_str(), reason.size());
|
||||
GenerateOutput(message, reason.c_str(), reason.size());
|
||||
return true;
|
||||
}
|
||||
|
||||
flatbuffers::Verifier verifier(reinterpret_cast<uint8_t *>(req_data), inputs[0]->size);
|
||||
flatbuffers::Verifier verifier(req_data, len);
|
||||
if (!verifier.VerifyBuffer<schema::RequestUpdateModel>()) {
|
||||
std::string reason = "The schema of RequestUpdateModel is invalid.";
|
||||
BuildUpdateModelRsp(fbb, schema::ResponseCode_RequestError, reason, "");
|
||||
MS_LOG(WARNING) << reason;
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
|
||||
ResultCode result_code = ReachThresholdForUpdateModel(fbb);
|
||||
if (result_code != ResultCode::kSuccess) {
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return ConvertResultCode(result_code);
|
||||
}
|
||||
|
||||
|
@ -84,7 +76,7 @@ bool UpdateModelKernel::Launch(const std::vector<AddressPtr> &inputs, const std:
|
|||
std::string reason = "Building flatbuffers schema failed for RequestUpdateModel.";
|
||||
BuildUpdateModelRsp(fbb, schema::ResponseCode_RequestError, reason, "");
|
||||
MS_LOG(WARNING) << reason;
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -95,7 +87,7 @@ bool UpdateModelKernel::Launch(const std::vector<AddressPtr> &inputs, const std:
|
|||
std::string reason = "verify signature failed.";
|
||||
BuildUpdateModelRsp(fbb, schema::ResponseCode_RequestError, reason, "");
|
||||
MS_LOG(WARNING) << reason;
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -103,7 +95,7 @@ bool UpdateModelKernel::Launch(const std::vector<AddressPtr> &inputs, const std:
|
|||
std::string reason = "verify signature timestamp failed.";
|
||||
BuildUpdateModelRsp(fbb, schema::ResponseCode_OutOfTime, reason, "");
|
||||
MS_LOG(WARNING) << reason;
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
MS_LOG(INFO) << "verify signature passed!";
|
||||
|
@ -112,25 +104,25 @@ bool UpdateModelKernel::Launch(const std::vector<AddressPtr> &inputs, const std:
|
|||
result_code = VerifyUpdateModel(update_model_req, fbb, &device_meta);
|
||||
if (result_code != ResultCode::kSuccess) {
|
||||
MS_LOG(WARNING) << "Updating model failed.";
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return ConvertResultCode(result_code);
|
||||
}
|
||||
|
||||
result_code = CountForUpdateModel(fbb, update_model_req);
|
||||
if (result_code != ResultCode::kSuccess) {
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return ConvertResultCode(result_code);
|
||||
}
|
||||
|
||||
result_code = UpdateModel(update_model_req, fbb, device_meta);
|
||||
if (result_code != ResultCode::kSuccess) {
|
||||
MS_LOG(WARNING) << "Updating model failed.";
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return ConvertResultCode(result_code);
|
||||
}
|
||||
|
||||
IncreaseAcceptClientNum();
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
@ -44,8 +44,7 @@ class UpdateModelKernel : public RoundKernel {
|
|||
~UpdateModelKernel() override = default;
|
||||
|
||||
void InitKernel(size_t threshold_count) override;
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs);
|
||||
bool Launch(const uint8_t *req_data, size_t len, const std::shared_ptr<ps::core::MessageHandler> &message) override;
|
||||
bool Reset() override;
|
||||
|
||||
// In some cases, the last updateModel message means this server iteration is finished.
|
||||
|
|
|
@ -36,28 +36,21 @@ Round::Round(const std::string &name, bool check_timeout, size_t time_window, bo
|
|||
threshold_count_(threshold_count),
|
||||
server_num_as_threshold_(server_num_as_threshold) {}
|
||||
|
||||
void Round::Initialize(const std::shared_ptr<ps::core::CommunicatorBase> &communicator, const TimeOutCb &timeout_cb,
|
||||
const FinishIterCb &finish_iteration_cb) {
|
||||
void Round::RegisterMsgCallBack(const std::shared_ptr<ps::core::CommunicatorBase> &communicator) {
|
||||
MS_EXCEPTION_IF_NULL(communicator);
|
||||
communicator_ = communicator;
|
||||
MS_LOG(INFO) << "Round " << name_ << " start initialize.";
|
||||
communicator_->RegisterMsgCallBack(name_, [&](std::shared_ptr<ps::core::MessageHandler> message) {
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(message);
|
||||
LaunchRoundKernel(message);
|
||||
});
|
||||
MS_LOG(INFO) << "Round " << name_ << " register message callback.";
|
||||
communicator->RegisterMsgCallBack(
|
||||
name_, [this](std::shared_ptr<ps::core::MessageHandler> message) { LaunchRoundKernel(message); });
|
||||
}
|
||||
|
||||
void Round::Initialize(const TimeOutCb &timeout_cb, const FinishIterCb &finish_iteration_cb) {
|
||||
MS_LOG(INFO) << "Round " << name_ << " start initialize.";
|
||||
// Callback when the iteration is finished.
|
||||
finish_iteration_cb_ = [this, finish_iteration_cb](bool, const std::string &) -> void {
|
||||
std::string reason = "Round " + name_ + " finished! This iteration is valid. Proceed to next iteration.";
|
||||
finish_iteration_cb(true, reason);
|
||||
};
|
||||
|
||||
// Callback for finalizing the server. This can only be called once.
|
||||
finalize_cb_ = [&](void) -> void {
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(communicator_);
|
||||
(void)communicator_->Stop();
|
||||
};
|
||||
|
||||
if (check_timeout_) {
|
||||
iter_timer_ = std::make_shared<IterationTimer>();
|
||||
MS_EXCEPTION_IF_NULL(iter_timer_);
|
||||
|
@ -69,7 +62,7 @@ void Round::Initialize(const std::shared_ptr<ps::core::CommunicatorBase> &commun
|
|||
});
|
||||
|
||||
// 2.Stopping timer callback which will be set to the round kernel.
|
||||
stop_timer_cb_ = [&](void) -> void {
|
||||
stop_timer_cb_ = [this](void) -> void {
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(iter_timer_);
|
||||
MS_LOG(INFO) << "Round " << name_ << " kernel stops its timer.";
|
||||
iter_timer_->Stop();
|
||||
|
@ -129,40 +122,17 @@ void Round::BindRoundKernel(const std::shared_ptr<kernel::RoundKernel> &kernel)
|
|||
void Round::LaunchRoundKernel(const std::shared_ptr<ps::core::MessageHandler> &message) {
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(message);
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(kernel_);
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(communicator_);
|
||||
|
||||
std::string reason = "";
|
||||
if (!IsServerAvailable(&reason)) {
|
||||
if (!communicator_->SendResponse(reason.c_str(), reason.size(), message)) {
|
||||
if (!message->SendResponse(reason.c_str(), reason.size())) {
|
||||
MS_LOG(WARNING) << "Sending response failed.";
|
||||
return;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
(void)(Iteration::GetInstance().running_round_num_++);
|
||||
AddressPtr input = std::make_shared<Address>();
|
||||
AddressPtr output = std::make_shared<Address>();
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(input);
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(output);
|
||||
input->addr = message->data();
|
||||
input->size = message->len();
|
||||
bool ret = kernel_->Launch({input}, {}, {output});
|
||||
if (output->size == 0) {
|
||||
reason = "The output of the round " + name_ + " is empty.";
|
||||
MS_LOG(WARNING) << reason;
|
||||
if (!communicator_->SendResponse(reason.c_str(), reason.size(), message)) {
|
||||
MS_LOG(ERROR) << "Sending response failed.";
|
||||
return;
|
||||
}
|
||||
return;
|
||||
}
|
||||
if (!communicator_->SendResponse(output->addr, output->size, message)) {
|
||||
MS_LOG(ERROR) << "Sending response failed.";
|
||||
return;
|
||||
}
|
||||
kernel_->Release(output);
|
||||
|
||||
bool ret = kernel_->Launch(reinterpret_cast<const uint8_t *>(message->data()), message->len(), message);
|
||||
// Must send response back no matter what value Launch method returns.
|
||||
if (!ret) {
|
||||
reason = "Launching round kernel of round " + name_ + " failed.";
|
||||
|
|
|
@ -37,8 +37,8 @@ class Round {
|
|||
bool check_count = false, size_t threshold_count = 8, bool server_num_as_threshold = false);
|
||||
~Round() = default;
|
||||
|
||||
void Initialize(const std::shared_ptr<ps::core::CommunicatorBase> &communicator, const TimeOutCb &timeout_cb,
|
||||
const FinishIterCb &finish_iteration_cb);
|
||||
void RegisterMsgCallBack(const std::shared_ptr<ps::core::CommunicatorBase> &communicator);
|
||||
void Initialize(const TimeOutCb &timeout_cb, const FinishIterCb &finish_iteration_cb);
|
||||
|
||||
// Reinitialize count service and round kernel of this round after scaling operations are done.
|
||||
bool ReInitForScaling(uint32_t server_num);
|
||||
|
@ -106,8 +106,6 @@ class Round {
|
|||
// Whether this round uses the server number as its threshold count.
|
||||
bool server_num_as_threshold_;
|
||||
|
||||
std::shared_ptr<ps::core::CommunicatorBase> communicator_;
|
||||
|
||||
// The round kernel for this Round.
|
||||
std::shared_ptr<kernel::RoundKernel> kernel_;
|
||||
|
||||
|
@ -117,7 +115,6 @@ class Round {
|
|||
// The callbacks which will be set to the round kernel.
|
||||
StopTimerCb stop_timer_cb_;
|
||||
FinishIterCb finish_iteration_cb_;
|
||||
FinalizeCb finalize_cb_;
|
||||
};
|
||||
} // namespace server
|
||||
} // namespace fl
|
||||
|
|
|
@ -280,7 +280,7 @@ bool AbstractNode::Send(const NodeRole &node_role, const uint32_t &rank_id, cons
|
|||
message_meta->set_role(node_info_.node_role_);
|
||||
message_meta->set_user_cmd(command);
|
||||
|
||||
auto client = GetOrCreateTcpClient(rank_id);
|
||||
auto client = GetOrCreateTcpClient(rank_id, node_role);
|
||||
MS_EXCEPTION_IF_NULL(client);
|
||||
if (!client->SendMessage(message_meta, Protos::RAW, message.get(), len)) {
|
||||
MS_LOG(WARNING) << "Client send message failed.";
|
||||
|
@ -330,7 +330,7 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &
|
|||
auto send = data.at(it);
|
||||
auto len = data_lens.at(it);
|
||||
|
||||
auto client = GetOrCreateTcpClient(rank_ids.at(it));
|
||||
auto client = GetOrCreateTcpClient(rank_ids.at(it), node_role);
|
||||
MS_EXCEPTION_IF_NULL(client);
|
||||
if (!client->SendMessage(message_meta, Protos::RAW, send.get(), len)) {
|
||||
MS_LOG(WARNING) << "Client send message failed.";
|
||||
|
|
|
@ -253,7 +253,7 @@ void HttpMessageHandler::SendResponse() {
|
|||
evhttp_send_reply(event_request_, resp_code_, "Client", resp_buf_);
|
||||
}
|
||||
|
||||
void HttpMessageHandler::QuickResponse(int code, const unsigned char *body, size_t len) {
|
||||
void HttpMessageHandler::QuickResponse(int code, const void *body, size_t len) {
|
||||
MS_EXCEPTION_IF_NULL(event_request_);
|
||||
MS_EXCEPTION_IF_NULL(body);
|
||||
MS_EXCEPTION_IF_NULL(resp_buf_);
|
||||
|
|
|
@ -91,7 +91,7 @@ class HttpMessageHandler {
|
|||
|
||||
// Make sure code and all response body has finished set
|
||||
void SendResponse();
|
||||
void QuickResponse(int code, const unsigned char *body, size_t len);
|
||||
void QuickResponse(int code, const void *body, size_t len);
|
||||
void SimpleResponse(int code, const HttpHeaders &headers, const std::string &body);
|
||||
void ErrorResponse(int code, const RequestProcessResult &status);
|
||||
|
||||
|
|
|
@ -32,7 +32,7 @@ size_t HttpMsgHandler::len() const { return len_; }
|
|||
|
||||
bool HttpMsgHandler::SendResponse(const void *data, const size_t &len) {
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(data, false);
|
||||
http_msg_->QuickResponse(kHttpSuccess, reinterpret_cast<unsigned char *>(const_cast<void *>(data)), len);
|
||||
http_msg_->QuickResponse(kHttpSuccess, data, len);
|
||||
return true;
|
||||
}
|
||||
} // namespace core
|
||||
|
|
Loading…
Reference in New Issue