!30058 FL: optimize kernel launch

Merge pull request !30058 from 徐永飞/r1.6
This commit is contained in:
i-robot 2022-02-15 09:03:04 +00:00 committed by Gitee
commit 4b1df4d5f5
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
38 changed files with 205 additions and 423 deletions

View File

@ -72,21 +72,18 @@ void Iteration::InitRounds(const std::vector<std::shared_ptr<ps::core::Communica
MS_LOG(EXCEPTION) << "Communicators for rounds is empty.";
return;
}
(void)std::for_each(communicators.begin(), communicators.end(),
[&](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) {
for (auto &round : rounds_) {
MS_EXCEPTION_IF_NULL(round);
round->Initialize(communicator, timeout_cb, finish_iteration_cb);
}
});
// The time window for one iteration, which will be used in some round kernels.
size_t iteration_time_window = std::accumulate(rounds_.begin(), rounds_.end(), IntToSize(0),
[](size_t total, const std::shared_ptr<Round> &round) {
MS_EXCEPTION_IF_NULL(round);
return round->check_timeout() ? total + round->time_window() : total;
});
size_t iteration_time_window = 0;
for (auto &round : rounds_) {
MS_EXCEPTION_IF_NULL(round);
round->Initialize(timeout_cb, finish_iteration_cb);
for (auto &communicator : communicators) {
round->RegisterMsgCallBack(communicator);
}
if (round->check_timeout()) {
iteration_time_window += round->time_window();
}
}
LocalMetaStore::GetInstance().put_value(kCtxTotalTimeoutDuration, iteration_time_window);
MS_LOG(INFO) << "Time window for one iteration is " << iteration_time_window;

View File

@ -160,7 +160,7 @@ class FedAvgKernel : public AggregationKernel {
bool IsAggregationDone() override { return done_; }
void SetParameterAddress(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
const std::vector<AddressPtr> &outputs) override {
weight_addr_ = inputs[0];
data_size_addr_ = inputs[1];
new_weight_addr_ = inputs[2];

View File

@ -131,31 +131,25 @@ bool ClientListKernel::DealClient(const size_t iter_num, const schema::GetClient
return true;
}
bool ClientListKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) {
bool ClientListKernel::Launch(const uint8_t *req_data, size_t len,
const std::shared_ptr<ps::core::MessageHandler> &message) {
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
MS_LOG(INFO) << "Launching ClientListKernel, Iteration number is " << iter_num;
if (inputs.size() != 1 || outputs.size() != 1) {
std::string reason = "inputs or outputs size is invalid.";
MS_LOG(ERROR) << reason;
return false;
}
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
void *req_data = inputs[0]->addr;
if (fbb == nullptr || req_data == nullptr) {
std::string reason = "FBBuilder builder or req_data is nullptr.";
MS_LOG(ERROR) << reason;
return false;
}
std::vector<string> client_list;
flatbuffers::Verifier verifier(reinterpret_cast<uint8_t *>(req_data), inputs[0]->size);
flatbuffers::Verifier verifier(req_data, len);
if (!verifier.VerifyBuffer<schema::GetClientList>()) {
std::string reason = "The schema of GetClientList is invalid.";
BuildClientListRsp(fbb, schema::ResponseCode_RequestError, reason, client_list,
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
const schema::GetClientList *get_clients_req = flatbuffers::GetRoot<schema::GetClientList>(req_data);
@ -164,7 +158,7 @@ bool ClientListKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
BuildClientListRsp(fbb, schema::ResponseCode_RequestError, reason, client_list,
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
// verify signature
@ -175,7 +169,7 @@ bool ClientListKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
BuildClientListRsp(fbb, schema::ResponseCode_RequestError, reason, client_list,
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
if (verify_result == sigVerifyResult::TIMEOUT) {
@ -183,7 +177,7 @@ bool ClientListKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
BuildClientListRsp(fbb, schema::ResponseCode_OutOfTime, reason, client_list,
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
MS_LOG(DEBUG) << "verify signature passed!";
@ -195,7 +189,7 @@ bool ClientListKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
<< ". client request iteration is " << iter_client;
BuildClientListRsp(fbb, schema::ResponseCode_OutOfTime, "iter num is error.", client_list,
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
@ -206,7 +200,7 @@ bool ClientListKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
if (!DealClient(iter_num, get_clients_req, fbb)) {
MS_LOG(WARNING) << "Get Client List not ready.";
}
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
} // namespace fl

View File

@ -37,8 +37,7 @@ class ClientListKernel : public RoundKernel {
ClientListKernel() = default;
~ClientListKernel() override = default;
void InitKernel(size_t required_cnt) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) override;
bool Launch(const uint8_t *req_data, size_t len, const std::shared_ptr<ps::core::MessageHandler> &message) override;
bool Reset() override;
void BuildClientListRsp(const std::shared_ptr<server::FBBuilder> &fbb, const schema::ResponseCode retcode,
const string &reason, std::vector<std::string> clients, const string &next_req_time,

View File

@ -123,18 +123,12 @@ sigVerifyResult ExchangeKeysKernel::VerifySignature(const schema::RequestExchang
return sigVerifyResult::PASSED;
}
bool ExchangeKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) {
bool ExchangeKeysKernel::Launch(const uint8_t *req_data, size_t len,
const std::shared_ptr<ps::core::MessageHandler> &message) {
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
MS_LOG(INFO) << "Launching ExchangeKey kernel, ITERATION NUMBER IS : " << iter_num;
bool response = false;
if (inputs.size() != 1 || outputs.size() != 1) {
std::string reason = "inputs or outputs size is invalid.";
MS_LOG(ERROR) << reason;
return false;
}
void *req_data = inputs[0]->addr;
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
if (fbb == nullptr || req_data == nullptr) {
std::string reason = "FBBuilder builder or req_data is nullptr.";
@ -143,17 +137,17 @@ bool ExchangeKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std
}
if (ReachThresholdForExchangeKeys(fbb, iter_num)) {
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
flatbuffers::Verifier verifier(reinterpret_cast<uint8_t *>(req_data), inputs[0]->size);
flatbuffers::Verifier verifier(req_data, len);
if (!verifier.VerifyBuffer<schema::RequestExchangeKeys>()) {
std::string reason = "The schema of RequestExchangeKeys is invalid.";
cipher_key_->BuildExchangeKeysRsp(fbb, schema::ResponseCode_RequestError, reason,
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
const schema::RequestExchangeKeys *exchange_keys_req = flatbuffers::GetRoot<schema::RequestExchangeKeys>(req_data);
@ -162,7 +156,7 @@ bool ExchangeKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std
cipher_key_->BuildExchangeKeysRsp(fbb, schema::ResponseCode_RequestError, reason,
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
@ -174,7 +168,7 @@ bool ExchangeKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std
cipher_key_->BuildExchangeKeysRsp(fbb, schema::ResponseCode_RequestError, reason,
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
@ -183,7 +177,7 @@ bool ExchangeKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std
cipher_key_->BuildExchangeKeysRsp(fbb, schema::ResponseCode_OutOfTime, reason,
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
@ -196,21 +190,21 @@ bool ExchangeKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std
<< ". client request iteration is " << iter_client;
cipher_key_->BuildExchangeKeysRsp(fbb, schema::ResponseCode_OutOfTime, "iter num is error.",
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
response = cipher_key_->ExchangeKeys(iter_num, std::to_string(CURRENT_TIME_MILLI.count()), exchange_keys_req, fbb);
if (!response) {
MS_LOG(ERROR) << "update exchange keys is failed.";
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
if (!CountForExchangeKeys(fbb, exchange_keys_req, iter_num)) {
MS_LOG(ERROR) << "count for exchange keys failed.";
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}

View File

@ -39,8 +39,7 @@ class ExchangeKeysKernel : public RoundKernel {
ExchangeKeysKernel() = default;
~ExchangeKeysKernel() override = default;
void InitKernel(size_t required_cnt) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) override;
bool Launch(const uint8_t *req_data, size_t len, const std::shared_ptr<ps::core::MessageHandler> &message) override;
bool Reset() override;
private:

View File

@ -90,19 +90,12 @@ sigVerifyResult GetKeysKernel::VerifySignature(const schema::GetExchangeKeys *ge
return sigVerifyResult::PASSED;
}
bool GetKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) {
bool GetKeysKernel::Launch(const uint8_t *req_data, size_t len,
const std::shared_ptr<ps::core::MessageHandler> &message) {
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
MS_LOG(INFO) << "Launching GetKeys kernel, ITERATION NUMBER IS : " << iter_num;
bool response = false;
if (inputs.size() != 1 || outputs.size() != 1) {
std::string reason = "inputs or outputs size is invalid.";
MS_LOG(ERROR) << reason;
return false;
}
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
void *req_data = inputs[0]->addr;
if (fbb == nullptr || req_data == nullptr) {
std::string reason = "FBBuilder builder or req_data is nullptr.";
MS_LOG(ERROR) << reason;
@ -111,13 +104,13 @@ bool GetKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vec
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
MS_LOG(WARNING) << "Current amount for GetKeysKernel is enough.";
}
flatbuffers::Verifier verifier(reinterpret_cast<uint8_t *>(req_data), inputs[0]->size);
flatbuffers::Verifier verifier(req_data, len);
if (!verifier.VerifyBuffer<schema::GetExchangeKeys>()) {
std::string reason = "The schema of GetExchangeKeys is invalid.";
cipher_key_->BuildGetKeysRsp(fbb, schema::ResponseCode_RequestError, iter_num,
std::to_string(CURRENT_TIME_MILLI.count()), false);
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
const schema::GetExchangeKeys *get_exchange_keys_req = flatbuffers::GetRoot<schema::GetExchangeKeys>(req_data);
@ -126,7 +119,7 @@ bool GetKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vec
cipher_key_->BuildGetKeysRsp(fbb, schema::ResponseCode_RequestError, iter_num,
std::to_string(CURRENT_TIME_MILLI.count()), false);
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
@ -138,7 +131,7 @@ bool GetKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vec
cipher_key_->BuildGetKeysRsp(fbb, schema::ResponseCode_RequestError, iter_num,
std::to_string(CURRENT_TIME_MILLI.count()), false);
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
@ -147,7 +140,7 @@ bool GetKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vec
cipher_key_->BuildGetKeysRsp(fbb, schema::ResponseCode_OutOfTime, iter_num,
std::to_string(CURRENT_TIME_MILLI.count()), false);
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
MS_LOG(INFO) << "verify signature passed!";
@ -159,20 +152,20 @@ bool GetKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vec
<< ". client request iteration is " << iter_client;
cipher_key_->BuildGetKeysRsp(fbb, schema::ResponseCode_OutOfTime, iter_num,
std::to_string(CURRENT_TIME_MILLI.count()), false);
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
response = cipher_key_->GetKeys(iter_num, std::to_string(CURRENT_TIME_MILLI.count()), get_exchange_keys_req, fbb);
if (!response) {
MS_LOG(WARNING) << "get public keys not ready.";
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
if (!CountForGetKeys(fbb, get_exchange_keys_req, iter_num)) {
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
GenerateOutput(outputs, fbb->GetCurrentBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetCurrentBufferPointer(), fbb->GetSize());
return true;
}

View File

@ -38,8 +38,7 @@ class GetKeysKernel : public RoundKernel {
GetKeysKernel() = default;
~GetKeysKernel() override = default;
void InitKernel(size_t required_cnt) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) override;
bool Launch(const uint8_t *req_data, size_t len, const std::shared_ptr<ps::core::MessageHandler> &message) override;
bool Reset() override;
private:

View File

@ -78,30 +78,24 @@ sigVerifyResult GetListSignKernel::VerifySignature(const schema::RequestAllClien
return sigVerifyResult::PASSED;
}
bool GetListSignKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) {
bool GetListSignKernel::Launch(const uint8_t *req_data, size_t len,
const std::shared_ptr<ps::core::MessageHandler> &message) {
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
MS_LOG(INFO) << "Launching GetListSign kernel, Iteration number is " << iter_num;
if (inputs.size() != 1 || outputs.size() != 1) {
std::string reason = "inputs or outputs size is invalid.";
MS_LOG(ERROR) << reason;
return false;
}
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
void *req_data = inputs[0]->addr;
if (fbb == nullptr || req_data == nullptr) {
std::string reason = "FBBuilder builder or req_data is nullptr.";
MS_LOG(ERROR) << reason;
return false;
}
std::map<std::string, std::vector<unsigned char>> list_signs;
flatbuffers::Verifier verifier(reinterpret_cast<uint8_t *>(req_data), inputs[0]->size);
flatbuffers::Verifier verifier(req_data, len);
if (!verifier.VerifyBuffer<schema::RequestAllClientListSign>()) {
std::string reason = "The schema of RequestAllClientListSign is invalid.";
BuildGetListSignKernelRsp(fbb, schema::ResponseCode_RequestError, reason,
std::to_string(CURRENT_TIME_MILLI.count()), iter_num, list_signs);
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
const schema::RequestAllClientListSign *get_list_sign_req =
@ -111,7 +105,7 @@ bool GetListSignKernel::Launch(const std::vector<AddressPtr> &inputs, const std:
BuildGetListSignKernelRsp(fbb, schema::ResponseCode_RequestError, reason,
std::to_string(CURRENT_TIME_MILLI.count()), iter_num, list_signs);
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
@ -123,7 +117,7 @@ bool GetListSignKernel::Launch(const std::vector<AddressPtr> &inputs, const std:
BuildGetListSignKernelRsp(fbb, schema::ResponseCode_RequestError, reason,
std::to_string(CURRENT_TIME_MILLI.count()), iter_num, list_signs);
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
@ -132,7 +126,7 @@ bool GetListSignKernel::Launch(const std::vector<AddressPtr> &inputs, const std:
BuildGetListSignKernelRsp(fbb, schema::ResponseCode_OutOfTime, reason, std::to_string(CURRENT_TIME_MILLI.count()),
iter_num, list_signs);
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
@ -147,7 +141,7 @@ bool GetListSignKernel::Launch(const std::vector<AddressPtr> &inputs, const std:
<< ". client request iteration is " << iter_client;
BuildGetListSignKernelRsp(fbb, schema::ResponseCode_OutOfTime, "iter num is error.",
std::to_string(CURRENT_TIME_MILLI.count()), iter_num, list_signs);
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
std::string fl_id = get_list_sign_req->fl_id()->str();
@ -156,7 +150,7 @@ bool GetListSignKernel::Launch(const std::vector<AddressPtr> &inputs, const std:
}
if (!GetListSign(iter_num, std::to_string(CURRENT_TIME_MILLI.count()), get_list_sign_req, fbb)) {
MS_LOG(WARNING) << "get list signs not ready.";
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
std::string count_reason = "";
@ -167,7 +161,7 @@ bool GetListSignKernel::Launch(const std::vector<AddressPtr> &inputs, const std:
MS_LOG(ERROR) << reason;
return true;
}
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}

View File

@ -38,8 +38,7 @@ class GetListSignKernel : public RoundKernel {
GetListSignKernel() = default;
~GetListSignKernel() override = default;
void InitKernel(size_t required_cnt) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) override;
bool Launch(const uint8_t *req_data, size_t len, const std::shared_ptr<ps::core::MessageHandler> &message) override;
bool Reset() override;
void BuildGetListSignKernelRsp(const std::shared_ptr<server::FBBuilder> &fbb, const schema::ResponseCode retcode,
const string &reason, const string &next_req_time, const size_t iteration,

View File

@ -39,31 +39,23 @@ void GetModelKernel::InitKernel(size_t) {
}
}
bool GetModelKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) {
if (inputs.size() != 1 || outputs.size() != 1) {
std::string reason = "inputs or outputs size is invalid.";
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, reason.c_str(), reason.size());
return true;
}
void *req_data = inputs[0]->addr;
bool GetModelKernel::Launch(const uint8_t *req_data, size_t len,
const std::shared_ptr<ps::core::MessageHandler> &message) {
std::shared_ptr<FBBuilder> fbb = std::make_shared<FBBuilder>();
if (fbb == nullptr || req_data == nullptr) {
std::string reason = "FBBuilder builder or req_data is nullptr.";
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, reason.c_str(), reason.size());
GenerateOutput(message, reason.c_str(), reason.size());
return true;
}
flatbuffers::Verifier verifier(reinterpret_cast<uint8_t *>(req_data), inputs[0]->size);
flatbuffers::Verifier verifier(req_data, len);
if (!verifier.VerifyBuffer<schema::RequestGetModel>()) {
std::string reason = "The schema of RequestGetModel is invalid.";
BuildGetModelRsp(fbb, schema::ResponseCode_RequestError, reason, LocalMetaStore::GetInstance().curr_iter_num(), {},
"");
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
@ -76,11 +68,11 @@ bool GetModelKernel::Launch(const std::vector<AddressPtr> &inputs, const std::ve
if (get_model_req == nullptr) {
std::string reason = "Building flatbuffers schema failed for RequestGetModel.";
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, reason.c_str(), reason.size());
GenerateOutput(message, reason.c_str(), reason.size());
return true;
}
GetModel(get_model_req, fbb);
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}

View File

@ -37,8 +37,7 @@ class GetModelKernel : public RoundKernel {
~GetModelKernel() override = default;
void InitKernel(size_t) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs);
bool Launch(const uint8_t *req_data, size_t len, const std::shared_ptr<ps::core::MessageHandler> &message) override;
bool Reset() override;
private:

View File

@ -91,31 +91,23 @@ sigVerifyResult GetSecretsKernel::VerifySignature(const schema::GetShareSecrets
return sigVerifyResult::PASSED;
}
bool GetSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) {
bool GetSecretsKernel::Launch(const uint8_t *req_data, size_t len,
const std::shared_ptr<ps::core::MessageHandler> &message) {
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
std::string next_timestamp = std::to_string(CURRENT_TIME_MILLI.count());
MS_LOG(INFO) << "Launching get secrets kernel, ITERATION NUMBER IS : " << iter_num;
if (inputs.size() != 1 || outputs.size() != 1) {
std::string reason = "inputs or outputs size is invalid.";
MS_LOG(ERROR) << reason;
return false;
}
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
void *req_data = inputs[0]->addr;
if (fbb == nullptr || req_data == nullptr) {
std::string reason = "FBBuilder builder or req_data is nullptr.";
MS_LOG(ERROR) << reason;
return false;
}
flatbuffers::Verifier verifier(reinterpret_cast<uint8_t *>(req_data), inputs[0]->size);
flatbuffers::Verifier verifier(req_data, len);
if (!verifier.VerifyBuffer<schema::GetShareSecrets>()) {
std::string reason = "The schema of GetShareSecrets is invalid.";
cipher_share_->BuildGetSecretsRsp(fbb, schema::ResponseCode_RequestError, iter_num, next_timestamp, nullptr);
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
const schema::GetShareSecrets *get_secrets_req = flatbuffers::GetRoot<schema::GetShareSecrets>(req_data);
@ -123,7 +115,7 @@ bool GetSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
std::string reason = "Building flatbuffers schema failed for GetExchangeKeys.";
cipher_share_->BuildGetSecretsRsp(fbb, schema::ResponseCode_RequestError, iter_num, next_timestamp, nullptr);
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
@ -134,7 +126,7 @@ bool GetSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
std::string reason = "verify signature failed.";
cipher_share_->BuildGetSecretsRsp(fbb, schema::ResponseCode_RequestError, iter_num, next_timestamp, nullptr);
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
@ -142,7 +134,7 @@ bool GetSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
std::string reason = "verify signature timestamp failed.";
cipher_share_->BuildGetSecretsRsp(fbb, schema::ResponseCode_OutOfTime, iter_num, next_timestamp, nullptr);
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
@ -155,7 +147,7 @@ bool GetSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
MS_LOG(ERROR) << "GetSecretsKernel iteration invalid. server now iteration is " << iter_num
<< ". client request iteration is " << iter_client;
cipher_share_->BuildGetSecretsRsp(fbb, schema::ResponseCode_OutOfTime, iter_num, next_timestamp, nullptr);
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
@ -166,14 +158,14 @@ bool GetSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
bool response = cipher_share_->GetSecrets(get_secrets_req, fbb, next_timestamp);
if (!response) {
MS_LOG(WARNING) << "get secret shares not ready.";
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
if (!CountForGetSecrets(fbb, get_secrets_req, iter_num)) {
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}

View File

@ -37,8 +37,7 @@ class GetSecretsKernel : public RoundKernel {
GetSecretsKernel() = default;
~GetSecretsKernel() override = default;
void InitKernel(size_t required_cnt) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) override;
bool Launch(const uint8_t *req_data, size_t len, const std::shared_ptr<ps::core::MessageHandler> &message) override;
bool Reset() override;
private:

View File

@ -34,17 +34,9 @@ void PullWeightKernel::InitKernel(size_t) {
}
}
bool PullWeightKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) {
bool PullWeightKernel::Launch(const uint8_t *req_data, size_t len,
const std::shared_ptr<ps::core::MessageHandler> &message) {
MS_LOG(DEBUG) << "Launching PullWeightKernel kernel.";
if (inputs.size() != 1 || outputs.size() != 1) {
std::string reason = "inputs or outputs size is invalid.";
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, reason.c_str(), reason.size());
return true;
}
void *req_data = inputs[0]->addr;
std::shared_ptr<FBBuilder> fbb = std::make_shared<FBBuilder>();
if (fbb == nullptr || req_data == nullptr) {
MS_LOG(ERROR) << "FBBuilder builder or req_data is nullptr.";
@ -56,12 +48,12 @@ bool PullWeightKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
std::string reason = "Building flatbuffers schema failed for RequestPullWeight";
BuildPullWeightRsp(fbb, schema::ResponseCode_RequestError, reason, LocalMetaStore::GetInstance().curr_iter_num(),
{});
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return false;
}
PullWeight(fbb, pull_weight_req);
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}

View File

@ -37,8 +37,7 @@ class PullWeightKernel : public RoundKernel {
~PullWeightKernel() override = default;
void InitKernel(size_t required_cnt) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs);
bool Launch(const uint8_t *req_data, size_t len, const std::shared_ptr<ps::core::MessageHandler> &message) override;
bool Reset() override;
private:

View File

@ -33,29 +33,23 @@ void PushListSignKernel::InitKernel(size_t) {
cipher_init_ = &armour::CipherInit::GetInstance();
}
bool PushListSignKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) {
bool PushListSignKernel::Launch(const uint8_t *req_data, size_t len,
const std::shared_ptr<ps::core::MessageHandler> &message) {
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
MS_LOG(INFO) << "Launching PushListSignKernel, Iteration number is " << iter_num;
if (inputs.size() != 1 || outputs.size() != 1) {
std::string reason = "inputs or outputs size is invalid.";
MS_LOG(ERROR) << reason;
return false;
}
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
void *req_data = inputs[0]->addr;
if (fbb == nullptr || req_data == nullptr) {
std::string reason = "FBBuilder builder or req_data is nullptr.";
MS_LOG(ERROR) << reason;
return false;
}
flatbuffers::Verifier verifier(reinterpret_cast<uint8_t *>(req_data), inputs[0]->size);
flatbuffers::Verifier verifier(req_data, len);
if (!verifier.VerifyBuffer<schema::SendClientListSign>()) {
std::string reason = "The schema of PushClientListSign is invalid.";
BuildPushListSignKernelRsp(fbb, schema::ResponseCode_RequestError, reason,
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
const schema::SendClientListSign *client_list_sign_req = flatbuffers::GetRoot<schema::SendClientListSign>(req_data);
@ -64,7 +58,7 @@ bool PushListSignKernel::Launch(const std::vector<AddressPtr> &inputs, const std
BuildPushListSignKernelRsp(fbb, schema::ResponseCode_RequestError, reason,
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
// verify signature
@ -75,7 +69,7 @@ bool PushListSignKernel::Launch(const std::vector<AddressPtr> &inputs, const std
BuildPushListSignKernelRsp(fbb, schema::ResponseCode_RequestError, reason,
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
if (verify_result == sigVerifyResult::TIMEOUT) {
@ -83,17 +77,17 @@ bool PushListSignKernel::Launch(const std::vector<AddressPtr> &inputs, const std
BuildPushListSignKernelRsp(fbb, schema::ResponseCode_OutOfTime, reason,
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
MS_LOG(INFO) << "verify signature passed!";
}
return LaunchForPushListSign(client_list_sign_req, iter_num, fbb, outputs);
return LaunchForPushListSign(client_list_sign_req, iter_num, fbb, message);
}
bool PushListSignKernel::LaunchForPushListSign(const schema::SendClientListSign *client_list_sign_req,
const size_t &iter_num, const std::shared_ptr<server::FBBuilder> &fbb,
const std::vector<AddressPtr> &outputs) {
const std::shared_ptr<ps::core::MessageHandler> &message) {
MS_ERROR_IF_NULL_W_RET_VAL(client_list_sign_req, false);
size_t iter_client = IntToSize(client_list_sign_req->iteration());
if (iter_num != iter_client) {
@ -102,7 +96,7 @@ bool PushListSignKernel::LaunchForPushListSign(const schema::SendClientListSign
MS_LOG(WARNING) << "server now iteration is " << iter_num << ". client request iteration is " << iter_client;
BuildPushListSignKernelRsp(fbb, schema::ResponseCode_OutOfTime, reason, std::to_string(CURRENT_TIME_MILLI.count()),
iter_num);
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
std::vector<string> update_model_clients;
@ -125,13 +119,13 @@ bool PushListSignKernel::LaunchForPushListSign(const schema::SendClientListSign
"Current amount for PushListSignKernel is enough.",
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
}
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
if (!PushListSign(iter_num, std::to_string(CURRENT_TIME_MILLI.count()), client_list_sign_req, fbb,
update_model_clients)) {
MS_LOG(ERROR) << "push client list sign failed.";
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
std::string count_reason = "";
@ -140,9 +134,10 @@ bool PushListSignKernel::LaunchForPushListSign(const schema::SendClientListSign
BuildPushListSignKernelRsp(fbb, schema::ResponseCode_OutOfTime, reason, std::to_string(CURRENT_TIME_MILLI.count()),
iter_num);
MS_LOG(ERROR) << reason;
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}

View File

@ -38,10 +38,10 @@ class PushListSignKernel : public RoundKernel {
PushListSignKernel() = default;
~PushListSignKernel() override = default;
void InitKernel(size_t required_cnt) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) override;
bool Launch(const uint8_t *req_data, size_t len, const std::shared_ptr<ps::core::MessageHandler> &message) override;
bool LaunchForPushListSign(const schema::SendClientListSign *client_list_sign_req, const size_t &iter_num,
const std::shared_ptr<server::FBBuilder> &fbb, const std::vector<AddressPtr> &outputs);
const std::shared_ptr<server::FBBuilder> &fbb,
const std::shared_ptr<ps::core::MessageHandler> &message);
bool Reset() override;
void BuildPushListSignKernelRsp(const std::shared_ptr<server::FBBuilder> &fbb, const schema::ResponseCode retcode,
const string &reason, const string &next_req_time, const size_t iteration);

View File

@ -24,24 +24,23 @@ namespace server {
namespace kernel {
// Initializes the kernel by caching the local rank reported by DistributedCountService in local_rank_.
// The size_t argument is unnamed and unused here.
void PushMetricsKernel::InitKernel(size_t) { local_rank_ = DistributedCountService::GetInstance().local_rank(); }
bool PushMetricsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) {
bool PushMetricsKernel::Launch(const uint8_t *req_data, size_t len,
const std::shared_ptr<ps::core::MessageHandler> &message) {
MS_LOG(INFO) << "Launching PushMetricsKernel kernel.";
void *req_data = inputs[0]->addr;
std::shared_ptr<FBBuilder> fbb = std::make_shared<FBBuilder>();
if (fbb == nullptr || req_data == nullptr) {
std::string reason = "FBBuilder builder or req_data is nullptr.";
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, reason.c_str(), reason.size());
GenerateOutput(message, reason.c_str(), reason.size());
return true;
}
flatbuffers::Verifier verifier(reinterpret_cast<uint8_t *>(req_data), inputs[0]->size);
flatbuffers::Verifier verifier(req_data, len);
if (!verifier.VerifyBuffer<schema::RequestPushMetrics>()) {
std::string reason = "The schema of RequestPushMetrics is invalid.";
BuildPushMetricsRsp(fbb, schema::ResponseCode_RequestError);
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
@ -50,12 +49,12 @@ bool PushMetricsKernel::Launch(const std::vector<AddressPtr> &inputs, const std:
std::string reason = "Building flatbuffers schema failed for RequestPushMetrics";
BuildPushMetricsRsp(fbb, schema::ResponseCode_RequestError);
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return false;
}
ResultCode result_code = PushMetrics(fbb, push_metrics_req);
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return ConvertResultCode(result_code);
}

View File

@ -36,8 +36,7 @@ class PushMetricsKernel : public RoundKernel {
~PushMetricsKernel() override = default;
void InitKernel(size_t threshold_count);
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs);
bool Launch(const uint8_t *req_data, size_t len, const std::shared_ptr<ps::core::MessageHandler> &message) override;
bool Reset() override;
void OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) override;

View File

@ -30,31 +30,23 @@ void PushWeightKernel::InitKernel(size_t) {
local_rank_ = DistributedCountService::GetInstance().local_rank();
}
bool PushWeightKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) {
bool PushWeightKernel::Launch(const uint8_t *req_data, size_t len,
const std::shared_ptr<ps::core::MessageHandler> &message) {
MS_LOG(INFO) << "Launching PushWeightKernel kernel.";
if (inputs.size() != 1 || outputs.size() != 1) {
std::string reason = "inputs or outputs size is invalid.";
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, reason.c_str(), reason.size());
return true;
}
void *req_data = inputs[0]->addr;
std::shared_ptr<FBBuilder> fbb = std::make_shared<FBBuilder>();
if (fbb == nullptr || req_data == nullptr) {
std::string reason = "FBBuilder builder or req_data is nullptr.";
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, reason.c_str(), reason.size());
GenerateOutput(message, reason.c_str(), reason.size());
return true;
}
flatbuffers::Verifier verifier(reinterpret_cast<uint8_t *>(req_data), inputs[0]->size);
flatbuffers::Verifier verifier(req_data, len);
if (!verifier.VerifyBuffer<schema::RequestPushWeight>()) {
std::string reason = "The schema of RequestPushWeight is invalid.";
BuildPushWeightRsp(fbb, schema::ResponseCode_RequestError, reason, LocalMetaStore::GetInstance().curr_iter_num());
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
@ -62,12 +54,12 @@ bool PushWeightKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
if (push_weight_req == nullptr) {
std::string reason = "Building flatbuffers schema failed for RequestPushWeight";
BuildPushWeightRsp(fbb, schema::ResponseCode_RequestError, reason, LocalMetaStore::GetInstance().curr_iter_num());
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return false;
}
ResultCode result_code = PushWeight(fbb, push_weight_req);
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return ConvertResultCode(result_code);
}

View File

@ -36,8 +36,7 @@ class PushWeightKernel : public RoundKernel {
~PushWeightKernel() override = default;
void InitKernel(size_t threshold_count) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs);
bool Launch(const uint8_t *req_data, size_t len, const std::shared_ptr<ps::core::MessageHandler> &message) override;
bool Reset() override;
void OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) override;

View File

@ -90,20 +90,13 @@ sigVerifyResult ReconstructSecretsKernel::VerifySignature(const schema::SendReco
return sigVerifyResult::PASSED;
}
bool ReconstructSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) {
bool ReconstructSecretsKernel::Launch(const uint8_t *req_data, size_t len,
const std::shared_ptr<ps::core::MessageHandler> &message) {
bool response = false;
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
MS_LOG(INFO) << "Launching ReconstructSecrets Kernel, Iteration number is " << iter_num;
if (inputs.size() != 1 || outputs.size() != 1) {
MS_LOG(ERROR) << "ReconstructSecretsKernel needs 1 input, but got " << inputs.size();
return false;
}
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
void *req_data = inputs[0]->addr;
if (fbb == nullptr || req_data == nullptr) {
std::string reason = "FBBuilder builder or req_data is nullptr.";
MS_LOG(ERROR) << reason;
@ -119,13 +112,13 @@ bool ReconstructSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, con
for (size_t i = 0; i < IntToSize(update_model_clients_pb.fl_id_size()); ++i) {
update_model_clients.push_back(update_model_clients_pb.fl_id(SizeToInt(i)));
}
flatbuffers::Verifier verifier(reinterpret_cast<uint8_t *>(req_data), inputs[0]->size);
flatbuffers::Verifier verifier(req_data, len);
if (!verifier.VerifyBuffer<schema::SendReconstructSecret>()) {
std::string reason = "The schema of SendReconstructSecret is invalid.";
cipher_reconstruct_.BuildReconstructSecretsRsp(fbb, schema::ResponseCode_RequestError, reason, SizeToInt(iter_num),
std::to_string(CURRENT_TIME_MILLI.count()));
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
const schema::SendReconstructSecret *reconstruct_secret_req =
@ -135,7 +128,7 @@ bool ReconstructSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, con
cipher_reconstruct_.BuildReconstructSecretsRsp(fbb, schema::ResponseCode_RequestError, reason, SizeToInt(iter_num),
std::to_string(CURRENT_TIME_MILLI.count()));
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
// verify signature
@ -146,7 +139,7 @@ bool ReconstructSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, con
cipher_reconstruct_.BuildReconstructSecretsRsp(fbb, schema::ResponseCode_RequestError, reason,
SizeToInt(iter_num), std::to_string(CURRENT_TIME_MILLI.count()));
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
@ -155,7 +148,7 @@ bool ReconstructSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, con
cipher_reconstruct_.BuildReconstructSecretsRsp(fbb, schema::ResponseCode_OutOfTime, reason, SizeToInt(iter_num),
std::to_string(CURRENT_TIME_MILLI.count()));
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
MS_LOG(INFO) << "verify signature passed!";
@ -174,7 +167,7 @@ bool ReconstructSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, con
"Current amount for ReconstructSecretsKernel is enough.",
SizeToInt(iter_num), std::to_string(CURRENT_TIME_MILLI.count()));
}
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
@ -186,7 +179,7 @@ bool ReconstructSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, con
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
MS_LOG(INFO) << "Current amount for ReconstructSecretsKernel is enough.";
}
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
MS_LOG(INFO) << "reconstruct_secrets_kernel success.";
if (!response) {

View File

@ -39,8 +39,7 @@ class ReconstructSecretsKernel : public RoundKernel {
~ReconstructSecretsKernel() override = default;
void InitKernel(size_t required_cnt) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) override;
bool Launch(const uint8_t *req_data, size_t len, const std::shared_ptr<ps::core::MessageHandler> &message) override;
bool Reset() override;
void OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) override;

View File

@ -27,39 +27,9 @@ namespace mindspore {
namespace fl {
namespace server {
namespace kernel {
// Constructor of RoundKernel.
// NOTE(review): this span comes from a diff view, so two definitions of the constructor appear here: a multi-line
// one that starts a background release thread, and a one-line one further down that only initializes members.
// The multi-line definition starts release_thread_, which keeps looping for as long as running_ is true.
RoundKernel::RoundKernel() : name_(""), current_count_(0), required_count_(0), error_reason_(""), running_(true) {
release_thread_ = std::thread([&]() {
while (running_.load()) {
// release_mtx_ guards the queue of addresses waiting to be released.
std::unique_lock<std::mutex> release_lock(release_mtx_);
// Check every kReleaseDuration milliseconds whether there is any data that needs to be released.
if (heap_data_to_release_.empty()) {
// Drop the lock before sleeping, then poll the queue again.
release_lock.unlock();
std::this_thread::sleep_for(std::chrono::milliseconds(kReleaseDuration));
continue;
}
// NOTE(review): the one-line constructor below appears to be the replacement definition shown by the diff view;
// it initializes only name_ and current_count_ and has an empty body.
RoundKernel::RoundKernel() : name_(""), current_count_(0) {}
// Take the front entry of the queue and release release_mtx_ before touching heap_data_.
AddressPtr addr_ptr = heap_data_to_release_.front();
heap_data_to_release_.pop();
release_lock.unlock();
// heap_data_mtx_ guards the map from address to the owned buffer.
std::unique_lock<std::mutex> heap_data_lock(heap_data_mtx_);
if (heap_data_.count(addr_ptr) == 0) {
MS_LOG(ERROR) << "The data is not stored.";
continue;
}
// Manually release the unique_ptr data, then erase its entry from the map.
heap_data_[addr_ptr].reset(nullptr);
(void)heap_data_.erase(heap_data_.find(addr_ptr));
}
});
}
// Destructor of RoundKernel.
// NOTE(review): this span comes from a diff view, so two definitions appear: the first clears running_ and joins
// release_thread_ when it is joinable; the second (last line) has an empty body.
RoundKernel::~RoundKernel() {
// Clearing running_ makes the loop condition `running_.load()` of the release thread false.
running_ = false;
if (release_thread_.joinable()) {
release_thread_.join();
}
}
RoundKernel::~RoundKernel() {}
// Default handler for the first-count event: ignores the message handler and does nothing.
void RoundKernel::OnFirstCountEvent(const std::shared_ptr<ps::core::MessageHandler> &) { return; }
@ -79,16 +49,6 @@ void RoundKernel::FinishIteration() const {
return;
}
// Queues addr_ptr for release by pushing it onto heap_data_to_release_ while holding release_mtx_.
// A null addr_ptr is logged as a warning and ignored.
void RoundKernel::Release(const AddressPtr &addr_ptr) {
if (addr_ptr == nullptr) {
MS_LOG(WARNING) << "Data to be released is empty.";
return;
}
// The lock is held until the function returns, covering the push below.
std::unique_lock<std::mutex> lock(release_mtx_);
heap_data_to_release_.push(addr_ptr);
return;
}
// Stores the given round kernel name in name_.
void RoundKernel::set_name(const std::string &name) { name_ = name; }
// Stores the given callback in stop_timer_cb_.
void RoundKernel::set_stop_timer_cb(const StopTimerCb &timer_stopper) { stop_timer_cb_ = timer_stopper; }
@ -97,36 +57,26 @@ void RoundKernel::set_finish_iteration_cb(const FinishIterCb &finish_iteration_c
finish_iteration_cb_ = finish_iteration_cb;
}
// Produces the response of this round from `data` (`len` bytes).
// NOTE(review): this span comes from a diff view and interleaves two variants of the function: one that copies the
// payload into a heap buffer referenced by outputs[0], and one that sends the payload through the message handler.
// The statements are kept exactly as shown; the comments below describe each group of statements.
void RoundKernel::GenerateOutput(const std::vector<AddressPtr> &outputs, const void *data, size_t len) {
if (data == nullptr) {
MS_LOG(WARNING) << "The data is nullptr.";
void RoundKernel::GenerateOutput(const std::shared_ptr<ps::core::MessageHandler> &message, const void *data,
size_t len) {
// Without a message handler no response can be sent, so only a warning is logged.
if (message == nullptr) {
MS_LOG(WARNING) << "The message handler is nullptr.";
return;
}
if (outputs.empty()) {
MS_LOG(WARNING) << "Generating output failed. Outputs size is empty.";
// Empty payload: reply with a textual reason that names this round (name_) instead of the payload.
if (data == nullptr || len == 0) {
std::string reason = "The output of the round " + name_ + " is empty.";
MS_LOG(WARNING) << reason;
if (!message->SendResponse(reason.c_str(), reason.size())) {
MS_LOG(WARNING) << "Sending response failed.";
return;
}
return;
}
// Copy the payload into a freshly allocated buffer of len bytes.
std::unique_ptr<unsigned char[]> output_data = std::make_unique<unsigned char[]>(len);
if (output_data == nullptr) {
MS_LOG(WARNING) << "Output data is nullptr.";
return;
}
size_t dst_size = len;
int ret = memcpy_s(output_data.get(), dst_size, data, len);
if (ret != 0) {
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
return;
}
// Point outputs[0] at the copy, then move ownership of the buffer into heap_data_ under heap_data_mtx_.
outputs[0]->addr = output_data.get();
outputs[0]->size = len;
std::unique_lock<std::mutex> lock(heap_data_mtx_);
(void)heap_data_.insert(std::make_pair(outputs[0], std::move(output_data)));
IncreaseTotalClientNum();
return;
// Send the payload through the message handler; a failed send is only logged.
if (!message->SendResponse(data, len)) {
MS_LOG(WARNING) << "Sending response failed.";
return;
}
}
// Adds one to total_client_num_.
void RoundKernel::IncreaseTotalClientNum() { total_client_num_ += 1; }

View File

@ -45,22 +45,18 @@ constexpr uint64_t kReleaseDuration = 100;
// For example, the main process of federated learning is:
// startFLJob round->updateModel round->getModel round.
class RoundKernel : virtual public CPUKernel {
class RoundKernel {
public:
RoundKernel();
virtual ~RoundKernel();
// RoundKernel doesn't use InitKernel method of base class CPUKernel to initialize. So implementation of this
// inherited method is empty.
void InitKernel(const CNodePtr &kernel_node) override {}
// Initialize RoundKernel with threshold_count which means that for every iteration, this round needs threshold_count
// messages.
virtual void InitKernel(size_t threshold_count) = 0;
// Launch the round kernel logic to handle the message passed by the communication module.
virtual bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) = 0;
virtual bool Launch(const uint8_t *req_data, size_t len,
const std::shared_ptr<ps::core::MessageHandler> &message) = 0;
// Some rounds could be stateful in an iteration. The Reset method resets the status of this round.
virtual bool Reset() = 0;
@ -78,10 +74,6 @@ class RoundKernel : virtual public CPUKernel {
// be called.
void FinishIteration() const;
// Release the response data allocated inside the round kernel.
// Server framework must call this after the response data is sent back.
void Release(const AddressPtr &addr_ptr);
// Set round kernel name, which could be used in round kernel's methods.
void set_name(const std::string &name);
@ -112,36 +104,16 @@ class RoundKernel : virtual public CPUKernel {
protected:
// Generate the response data of this round. The data is allocated on the heap to ensure it is not released before
// being sent back to the worker.
void GenerateOutput(const std::vector<AddressPtr> &outputs, const void *data, size_t len);
void GenerateOutput(const std::shared_ptr<ps::core::MessageHandler> &message, const void *data, size_t len);
// Round kernel's name.
std::string name_;
// The current received message count for this round in this iteration.
size_t current_count_;
// The required received message count for this round in one iteration.
size_t required_count_;
// The reason that caused the error in this round kernel.
std::string error_reason_;
StopTimerCb stop_timer_cb_;
FinishIterCb finish_iteration_cb_;
// Members below are used for allocating and releasing response data on the heap.
// To ensure the performance, we use another thread to release data on the heap. So the operation on the data should
// be threadsafe.
std::atomic_bool running_;
std::thread release_thread_;
// Data that needs to be released, and its mutex.
std::mutex release_mtx_;
std::queue<AddressPtr> heap_data_to_release_;
std::mutex heap_data_mtx_;
std::unordered_map<AddressPtr, std::unique_ptr<unsigned char[]>> heap_data_;
std::atomic<size_t> total_client_num_;
std::atomic<size_t> accept_client_num_;

View File

@ -89,19 +89,13 @@ sigVerifyResult ShareSecretsKernel::VerifySignature(const schema::RequestShareSe
return sigVerifyResult::PASSED;
}
bool ShareSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) {
bool ShareSecretsKernel::Launch(const uint8_t *req_data, size_t len,
const std::shared_ptr<ps::core::MessageHandler> &message) {
bool response = false;
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
MS_LOG(INFO) << "Launching ShareSecretsKernel, ITERATION NUMBER IS : " << iter_num;
if (inputs.size() != 1 || outputs.size() != 1) {
std::string reason = "inputs or outputs size is invalid.";
MS_LOG(ERROR) << reason;
return false;
}
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
void *req_data = inputs[0]->addr;
if (fbb == nullptr || req_data == nullptr) {
std::string reason = "FBBuilder builder or req_data is nullptr.";
MS_LOG(ERROR) << reason;
@ -112,16 +106,16 @@ bool ShareSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std
cipher_share_->BuildShareSecretsRsp(fbb, schema::ResponseCode_OutOfTime,
"Current amount for ShareSecretsKernel is enough.",
std::to_string(CURRENT_TIME_MILLI.count()), SizeToInt(iter_num));
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
flatbuffers::Verifier verifier(reinterpret_cast<uint8_t *>(req_data), inputs[0]->size);
flatbuffers::Verifier verifier(req_data, len);
if (!verifier.VerifyBuffer<schema::RequestShareSecrets>()) {
std::string reason = "The schema of RequestShareSecrets is invalid.";
cipher_share_->BuildShareSecretsRsp(fbb, schema::ResponseCode_RequestError, reason,
std::to_string(CURRENT_TIME_MILLI.count()), SizeToInt(iter_num));
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
const schema::RequestShareSecrets *share_secrets_req = flatbuffers::GetRoot<schema::RequestShareSecrets>(req_data);
@ -130,7 +124,7 @@ bool ShareSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std
cipher_share_->BuildShareSecretsRsp(fbb, schema::ResponseCode_RequestError, reason,
std::to_string(CURRENT_TIME_MILLI.count()), SizeToInt(iter_num));
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
// verify signature
@ -141,7 +135,7 @@ bool ShareSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std
cipher_share_->BuildShareSecretsRsp(fbb, schema::ResponseCode_RequestError, reason,
std::to_string(CURRENT_TIME_MILLI.count()), SizeToInt(iter_num));
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
@ -150,7 +144,7 @@ bool ShareSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std
cipher_share_->BuildShareSecretsRsp(fbb, schema::ResponseCode_OutOfTime, reason,
std::to_string(CURRENT_TIME_MILLI.count()), SizeToInt(iter_num));
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
MS_LOG(INFO) << "verify signature passed!";
@ -162,21 +156,21 @@ bool ShareSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std
<< ". client request iteration is " << iter_client;
cipher_share_->BuildShareSecretsRsp(fbb, schema::ResponseCode_OutOfTime, "ShareSecretsKernel iteration invalid",
std::to_string(CURRENT_TIME_MILLI.count()), SizeToInt(iter_num));
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
response = cipher_share_->ShareSecrets(SizeToInt(iter_num), share_secrets_req, fbb,
std::to_string(CURRENT_TIME_MILLI.count()));
if (!response) {
MS_LOG(ERROR) << "update secret shares is failed.";
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
if (!CountForShareSecrets(fbb, share_secrets_req, iter_num)) {
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}

View File

@ -38,8 +38,7 @@ class ShareSecretsKernel : public RoundKernel {
ShareSecretsKernel() = default;
~ShareSecretsKernel() override = default;
void InitKernel(size_t required_cnt) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) override;
bool Launch(const uint8_t *req_data, size_t len, const std::shared_ptr<ps::core::MessageHandler> &message) override;
bool Reset() override;
private:

View File

@ -51,36 +51,29 @@ void StartFLJobKernel::InitKernel(size_t) {
return;
}
bool StartFLJobKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) {
bool StartFLJobKernel::Launch(const uint8_t *req_data, size_t len,
const std::shared_ptr<ps::core::MessageHandler> &message) {
MS_LOG(DEBUG) << "Launching StartFLJobKernel kernel.";
if (inputs.size() != 1 || outputs.size() != 1) {
std::string reason = "inputs or outputs size is invalid.";
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, reason.c_str(), reason.size());
return true;
}
void *req_data = inputs[0]->addr;
std::shared_ptr<FBBuilder> fbb = std::make_shared<FBBuilder>();
if (fbb == nullptr || req_data == nullptr) {
std::string reason = "FBBuilder builder or req_data is nullptr.";
MS_LOG(WARNING) << reason;
GenerateOutput(outputs, reason.c_str(), reason.size());
GenerateOutput(message, reason.c_str(), reason.size());
return true;
}
flatbuffers::Verifier verifier(reinterpret_cast<uint8_t *>(req_data), inputs[0]->size);
flatbuffers::Verifier verifier(req_data, len);
if (!verifier.VerifyBuffer<schema::RequestFLJob>()) {
std::string reason = "The schema of RequestFLJob is invalid.";
BuildStartFLJobRsp(fbb, schema::ResponseCode_RequestError, reason, false, "");
MS_LOG(WARNING) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
ResultCode result_code = ReachThresholdForStartFLJob(fbb);
if (result_code != ResultCode::kSuccess) {
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return ConvertResultCode(result_code);
}
@ -91,17 +84,17 @@ bool StartFLJobKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
fbb, schema::ResponseCode_RequestError, reason, false,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
MS_LOG(WARNING) << reason;
GenerateOutput(outputs, reason.c_str(), reason.size());
GenerateOutput(message, reason.c_str(), reason.size());
return true;
}
if (ps::PSContext::instance()->pki_verify()) {
if (!JudgeFLJobCert(fbb, start_fl_job_req)) {
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
if (!StoreKeyAttestation(fbb, start_fl_job_req)) {
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
}
@ -109,7 +102,7 @@ bool StartFLJobKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
DeviceMeta device_meta = CreateDeviceMetadata(start_fl_job_req);
result_code = ReadyForStartFLJob(fbb, device_meta);
if (result_code != ResultCode::kSuccess) {
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return ConvertResultCode(result_code);
}
PBMetadata metadata;
@ -120,7 +113,7 @@ bool StartFLJobKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
BuildStartFLJobRsp(
fbb, schema::ResponseCode_OutOfTime, reason, false,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return update_reason == kNetworkError ? false : true;
}
@ -128,11 +121,11 @@ bool StartFLJobKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
// If calling ReportCount before ReadyForStartFLJob, the result will be inconsistent if the device is not selected.
result_code = CountForStartFLJob(fbb, start_fl_job_req);
if (result_code != ResultCode::kSuccess) {
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return ConvertResultCode(result_code);
}
IncreaseAcceptClientNum();
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}

View File

@ -36,8 +36,7 @@ class StartFLJobKernel : public RoundKernel {
~StartFLJobKernel() override = default;
void InitKernel(size_t threshold_count) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
bool Launch(const uint8_t *req_data, size_t len, const std::shared_ptr<ps::core::MessageHandler> &message) override;
bool Reset() override;
void OnFirstCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) override;

View File

@ -45,37 +45,29 @@ void UpdateModelKernel::InitKernel(size_t threshold_count) {
LocalMetaStore::GetInstance().put_value(kCtxFedAvgTotalDataSize, kInitialDataSizeSum);
}
// Handles one updateModel request and always produces a reply through GenerateOutput before returning.
// NOTE(review): this text is a rendered diff with the +/- markers stripped, so superseded and replacement lines
// appear back to back. The statements that use `inputs`/`outputs` look like the pre-change form; the adjacent
// statements that use `req_data`/`len`/`message` look like the post-change form. Confirm against the real file.
// Return value: true on the paths that end in `return true`, otherwise ConvertResultCode(result_code).
// Superseded signature (address-vector interface).
bool UpdateModelKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) {
// Replacement signature: raw request bytes, their length, and the handler the reply is written to.
bool UpdateModelKernel::Launch(const uint8_t *req_data, size_t len,
const std::shared_ptr<ps::core::MessageHandler> &message) {
MS_LOG(DEBUG) << "Launching UpdateModelKernel kernel.";
// Superseded size check on the address vectors, together with the extraction of req_data from inputs[0].
if (inputs.size() != 1 || outputs.size() != 1) {
std::string reason = "inputs or outputs size is invalid.";
MS_LOG(WARNING) << reason;
GenerateOutput(outputs, reason.c_str(), reason.size());
return true;
}
void *req_data = inputs[0]->addr;
std::shared_ptr<FBBuilder> fbb = std::make_shared<FBBuilder>();
// Without a builder or a request buffer no schema response can be built, so the plain reason text is sent instead.
if (fbb == nullptr || req_data == nullptr) {
std::string reason = "FBBuilder builder or req_data is nullptr.";
MS_LOG(WARNING) << reason;
GenerateOutput(outputs, reason.c_str(), reason.size());
GenerateOutput(message, reason.c_str(), reason.size());
return true;
}
// Verify the buffer as a RequestUpdateModel flatbuffer before anything reads from it.
flatbuffers::Verifier verifier(reinterpret_cast<uint8_t *>(req_data), inputs[0]->size);
flatbuffers::Verifier verifier(req_data, len);
if (!verifier.VerifyBuffer<schema::RequestUpdateModel>()) {
std::string reason = "The schema of RequestUpdateModel is invalid.";
BuildUpdateModelRsp(fbb, schema::ResponseCode_RequestError, reason, "");
MS_LOG(WARNING) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
// On a non-success result, reply with whatever fbb currently holds and map the result code to the boolean return.
ResultCode result_code = ReachThresholdForUpdateModel(fbb);
if (result_code != ResultCode::kSuccess) {
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return ConvertResultCode(result_code);
}
@ -84,7 +76,7 @@ bool UpdateModelKernel::Launch(const std::vector<AddressPtr> &inputs, const std:
std::string reason = "Building flatbuffers schema failed for RequestUpdateModel.";
BuildUpdateModelRsp(fbb, schema::ResponseCode_RequestError, reason, "");
MS_LOG(WARNING) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
@ -95,7 +87,7 @@ bool UpdateModelKernel::Launch(const std::vector<AddressPtr> &inputs, const std:
std::string reason = "verify signature failed.";
BuildUpdateModelRsp(fbb, schema::ResponseCode_RequestError, reason, "");
MS_LOG(WARNING) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
@ -103,7 +95,7 @@ bool UpdateModelKernel::Launch(const std::vector<AddressPtr> &inputs, const std:
std::string reason = "verify signature timestamp failed.";
BuildUpdateModelRsp(fbb, schema::ResponseCode_OutOfTime, reason, "");
MS_LOG(WARNING) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
MS_LOG(INFO) << "verify signature passed!";
@ -112,25 +104,25 @@ bool UpdateModelKernel::Launch(const std::vector<AddressPtr> &inputs, const std:
// Remaining stages run in order: VerifyUpdateModel, CountForUpdateModel, UpdateModel. Each one replies and
// returns early on a non-success result.
result_code = VerifyUpdateModel(update_model_req, fbb, &device_meta);
if (result_code != ResultCode::kSuccess) {
MS_LOG(WARNING) << "Updating model failed.";
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return ConvertResultCode(result_code);
}
result_code = CountForUpdateModel(fbb, update_model_req);
if (result_code != ResultCode::kSuccess) {
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return ConvertResultCode(result_code);
}
result_code = UpdateModel(update_model_req, fbb, device_meta);
if (result_code != ResultCode::kSuccess) {
MS_LOG(WARNING) << "Updating model failed.";
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return ConvertResultCode(result_code);
}
// Every stage succeeded: call IncreaseAcceptClientNum(), then send the reply held in fbb.
IncreaseAcceptClientNum();
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
GenerateOutput(message, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}

View File

@ -44,8 +44,7 @@ class UpdateModelKernel : public RoundKernel {
~UpdateModelKernel() override = default;
void InitKernel(size_t threshold_count) override;
// Pre-change declaration (address-vector interface, not marked override).
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs);
// Post-change declaration: takes the raw request buffer, its length, and the message handler, and is declared as
// an override of a base-class virtual.
bool Launch(const uint8_t *req_data, size_t len, const std::shared_ptr<ps::core::MessageHandler> &message) override;
bool Reset() override;
// In some cases, the last updateModel message means this server iteration is finished.

View File

@ -36,28 +36,21 @@ Round::Round(const std::string &name, bool check_timeout, size_t time_window, bo
threshold_count_(threshold_count),
server_num_as_threshold_(server_num_as_threshold) {}
// NOTE(review): rendered diff with +/- markers stripped; superseded and replacement lines are interleaved below.
// Superseded signature: Initialize used to receive the communicator together with the two callbacks.
void Round::Initialize(const std::shared_ptr<ps::core::CommunicatorBase> &communicator, const TimeOutCb &timeout_cb,
const FinishIterCb &finish_iteration_cb) {
// Replacement: registers a handler under this round's name_ on `communicator`; the handler forwards every
// incoming message to LaunchRoundKernel.
void Round::RegisterMsgCallBack(const std::shared_ptr<ps::core::CommunicatorBase> &communicator) {
MS_EXCEPTION_IF_NULL(communicator);
// Superseded lines: the communicator was stored in communicator_, and the lambda captured by reference and
// null-checked the message itself.
communicator_ = communicator;
MS_LOG(INFO) << "Round " << name_ << " start initialize.";
communicator_->RegisterMsgCallBack(name_, [&](std::shared_ptr<ps::core::MessageHandler> message) {
MS_ERROR_IF_NULL_WO_RET_VAL(message);
LaunchRoundKernel(message);
});
// Replacement lines: nothing is stored on the Round and the lambda captures only `this`.
MS_LOG(INFO) << "Round " << name_ << " register message callback.";
communicator->RegisterMsgCallBack(
name_, [this](std::shared_ptr<ps::core::MessageHandler> message) { LaunchRoundKernel(message); });
}
void Round::Initialize(const TimeOutCb &timeout_cb, const FinishIterCb &finish_iteration_cb) {
MS_LOG(INFO) << "Round " << name_ << " start initialize.";
// Callback when the iteration is finished.
finish_iteration_cb_ = [this, finish_iteration_cb](bool, const std::string &) -> void {
std::string reason = "Round " + name_ + " finished! This iteration is valid. Proceed to next iteration.";
finish_iteration_cb(true, reason);
};
// Callback for finalizing the server. This can only be called once.
finalize_cb_ = [&](void) -> void {
MS_ERROR_IF_NULL_WO_RET_VAL(communicator_);
(void)communicator_->Stop();
};
if (check_timeout_) {
iter_timer_ = std::make_shared<IterationTimer>();
MS_EXCEPTION_IF_NULL(iter_timer_);
@ -69,7 +62,7 @@ void Round::Initialize(const std::shared_ptr<ps::core::CommunicatorBase> &commun
});
// 2.Stopping timer callback which will be set to the round kernel.
stop_timer_cb_ = [&](void) -> void {
stop_timer_cb_ = [this](void) -> void {
MS_ERROR_IF_NULL_WO_RET_VAL(iter_timer_);
MS_LOG(INFO) << "Round " << name_ << " kernel stops its timer.";
iter_timer_->Stop();
@ -129,40 +122,17 @@ void Round::BindRoundKernel(const std::shared_ptr<kernel::RoundKernel> &kernel)
void Round::LaunchRoundKernel(const std::shared_ptr<ps::core::MessageHandler> &message) {
MS_ERROR_IF_NULL_WO_RET_VAL(message);
MS_ERROR_IF_NULL_WO_RET_VAL(kernel_);
MS_ERROR_IF_NULL_WO_RET_VAL(communicator_);
std::string reason = "";
if (!IsServerAvailable(&reason)) {
if (!communicator_->SendResponse(reason.c_str(), reason.size(), message)) {
if (!message->SendResponse(reason.c_str(), reason.size())) {
MS_LOG(WARNING) << "Sending response failed.";
return;
}
return;
}
(void)(Iteration::GetInstance().running_round_num_++);
AddressPtr input = std::make_shared<Address>();
AddressPtr output = std::make_shared<Address>();
MS_ERROR_IF_NULL_WO_RET_VAL(input);
MS_ERROR_IF_NULL_WO_RET_VAL(output);
input->addr = message->data();
input->size = message->len();
bool ret = kernel_->Launch({input}, {}, {output});
if (output->size == 0) {
reason = "The output of the round " + name_ + " is empty.";
MS_LOG(WARNING) << reason;
if (!communicator_->SendResponse(reason.c_str(), reason.size(), message)) {
MS_LOG(ERROR) << "Sending response failed.";
return;
}
return;
}
if (!communicator_->SendResponse(output->addr, output->size, message)) {
MS_LOG(ERROR) << "Sending response failed.";
return;
}
kernel_->Release(output);
bool ret = kernel_->Launch(reinterpret_cast<const uint8_t *>(message->data()), message->len(), message);
// Must send response back no matter what value Launch method returns.
if (!ret) {
reason = "Launching round kernel of round " + name_ + " failed.";

View File

@ -37,8 +37,8 @@ class Round {
bool check_count = false, size_t threshold_count = 8, bool server_num_as_threshold = false);
~Round() = default;
// Pre-change declaration: a single Initialize call took the communicator together with both callbacks.
void Initialize(const std::shared_ptr<ps::core::CommunicatorBase> &communicator, const TimeOutCb &timeout_cb,
const FinishIterCb &finish_iteration_cb);
// Post-change declarations: the communicator goes to RegisterMsgCallBack, and Initialize takes only the timeout
// and finish-iteration callbacks.
void RegisterMsgCallBack(const std::shared_ptr<ps::core::CommunicatorBase> &communicator);
void Initialize(const TimeOutCb &timeout_cb, const FinishIterCb &finish_iteration_cb);
// Reinitialize count service and round kernel of this round after scaling operations are done.
bool ReInitForScaling(uint32_t server_num);
@ -106,8 +106,6 @@ class Round {
// Whether this round uses the server number as its threshold count.
bool server_num_as_threshold_;
std::shared_ptr<ps::core::CommunicatorBase> communicator_;
// The round kernel for this Round.
std::shared_ptr<kernel::RoundKernel> kernel_;
@ -117,7 +115,6 @@ class Round {
// The callbacks which will be set to the round kernel.
StopTimerCb stop_timer_cb_;
FinishIterCb finish_iteration_cb_;
FinalizeCb finalize_cb_;
};
} // namespace server
} // namespace fl

View File

@ -280,7 +280,7 @@ bool AbstractNode::Send(const NodeRole &node_role, const uint32_t &rank_id, cons
message_meta->set_role(node_info_.node_role_);
message_meta->set_user_cmd(command);
auto client = GetOrCreateTcpClient(rank_id);
auto client = GetOrCreateTcpClient(rank_id, node_role);
MS_EXCEPTION_IF_NULL(client);
if (!client->SendMessage(message_meta, Protos::RAW, message.get(), len)) {
MS_LOG(WARNING) << "Client send message failed.";
@ -330,7 +330,7 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &
auto send = data.at(it);
auto len = data_lens.at(it);
auto client = GetOrCreateTcpClient(rank_ids.at(it));
auto client = GetOrCreateTcpClient(rank_ids.at(it), node_role);
MS_EXCEPTION_IF_NULL(client);
if (!client->SendMessage(message_meta, Protos::RAW, send.get(), len)) {
MS_LOG(WARNING) << "Client send message failed.";

View File

@ -253,7 +253,7 @@ void HttpMessageHandler::SendResponse() {
evhttp_send_reply(event_request_, resp_code_, "Client", resp_buf_);
}
void HttpMessageHandler::QuickResponse(int code, const unsigned char *body, size_t len) {
void HttpMessageHandler::QuickResponse(int code, const void *body, size_t len) {
MS_EXCEPTION_IF_NULL(event_request_);
MS_EXCEPTION_IF_NULL(body);
MS_EXCEPTION_IF_NULL(resp_buf_);

View File

@ -91,7 +91,7 @@ class HttpMessageHandler {
// Make sure code and all response body has finished set
void SendResponse();
void QuickResponse(int code, const unsigned char *body, size_t len);
void QuickResponse(int code, const void *body, size_t len);
void SimpleResponse(int code, const HttpHeaders &headers, const std::string &body);
void ErrorResponse(int code, const RequestProcessResult &status);

View File

@ -32,7 +32,7 @@ size_t HttpMsgHandler::len() const { return len_; }
// Replies to the HTTP request with `len` bytes from `data` and status kHttpSuccess.
// Returns true after handing the body to QuickResponse; a null `data` is rejected first
// (presumably returning false via MS_ERROR_IF_NULL_W_RET_VAL — confirm against the macro definition).
bool HttpMsgHandler::SendResponse(const void *data, const size_t &len) {
MS_ERROR_IF_NULL_W_RET_VAL(data, false);
// Superseded call: cast constness away to match the former `const unsigned char *` parameter.
http_msg_->QuickResponse(kHttpSuccess, reinterpret_cast<unsigned char *>(const_cast<void *>(data)), len);
// Replacement call: passes `data` through unchanged, with no cast.
http_msg_->QuickResponse(kHttpSuccess, data, len);
return true;
}
} // namespace core