forked from mindspore-Ecosystem/mindspore
Modify updateModel kernel for STABLE_PW_ENCRYPT.
This commit is contained in:
parent
3c3d10fa31
commit
c4e84eb207
|
@ -147,6 +147,10 @@ std::vector<uint8_t> ExchangeKeysKernel::GetPubicKeyBytes() {
|
|||
}
|
||||
// pubLen has been updated, now get public_key bytes
|
||||
secret_pubkey_ptr = reinterpret_cast<uint8_t *>(malloc(pubLen));
|
||||
if (secret_pubkey_ptr == nullptr) {
|
||||
MS_LOG(ERROR) << "secret_pubkey_ptr is nullptr, malloc failed.";
|
||||
return {};
|
||||
}
|
||||
ret = sPriKeyPtr->GetPublicBytes(&pubLen, secret_pubkey_ptr);
|
||||
if (ret != 0) {
|
||||
free(secret_pubkey_ptr);
|
||||
|
|
|
@ -24,9 +24,11 @@
|
|||
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
|
||||
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
|
||||
#include "fl/worker/fl_worker.h"
|
||||
#include "fl/armour/secure_protocol/masking.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
constexpr int SECRET_MAX_LEN = 32;
|
||||
class UpdateModelKernel : public CPUKernel {
|
||||
public:
|
||||
UpdateModelKernel() = default;
|
||||
|
@ -45,6 +47,10 @@ class UpdateModelKernel : public CPUKernel {
|
|||
return false;
|
||||
}
|
||||
|
||||
if (encrypt_mode.compare("STABLE_PW_ENCRYPT") == 0) {
|
||||
EncryptData(inputs);
|
||||
}
|
||||
|
||||
if (!BuildUpdateModelReq(fbb_, inputs)) {
|
||||
MS_LOG(EXCEPTION) << "Building request for FusedPushWeight failed.";
|
||||
return false;
|
||||
|
@ -92,8 +98,22 @@ class UpdateModelKernel : public CPUKernel {
|
|||
target_server_rank_ = rank_id_ % server_num_;
|
||||
fl_name_ = fl::worker::FLWorker::GetInstance().fl_name();
|
||||
fl_id_ = fl::worker::FLWorker::GetInstance().fl_id();
|
||||
encrypt_mode = AnfAlgo::GetNodeAttr<string>(kernel_node, "encrypt_mode");
|
||||
if (encrypt_mode.compare("") != 0 && encrypt_mode.compare("STABLE_PW_ENCRYPT") != 0) {
|
||||
MS_LOG(EXCEPTION) << "Value Error: the parameter 'encrypt_mode' of updateModel kernel can only be '' or "
|
||||
"'STABLE_PW_ENCRYPT' until now, but got: "
|
||||
<< encrypt_mode;
|
||||
}
|
||||
MS_LOG(INFO) << "Initializing StartFLJob kernel. fl_name: " << fl_name_ << ", fl_id: " << fl_id_
|
||||
<< ". Request will be sent to server " << target_server_rank_;
|
||||
if (encrypt_mode.compare("STABLE_PW_ENCRYPT") == 0) {
|
||||
MS_LOG(INFO) << "STABLE_PW_ENCRYPT mode is open, model weights will be encrypted before send to server.";
|
||||
client_keys = fl::worker::FLWorker::GetInstance().public_keys_list();
|
||||
if (client_keys.size() == 0) {
|
||||
MS_LOG(EXCEPTION) << "The size of local-stored client_keys_list is 0, please check whether P.ExchangeKeys() "
|
||||
"and P.GetKeys() have been executed before updateModel.";
|
||||
}
|
||||
}
|
||||
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
for (size_t i = 0; i < input_num; i++) {
|
||||
|
@ -156,6 +176,85 @@ class UpdateModelKernel : public CPUKernel {
|
|||
return true;
|
||||
}
|
||||
|
||||
void EncryptData(const std::vector<AddressPtr> &inputs) {
|
||||
// calculate the sum of all layer's weight size
|
||||
size_t total_size = 0;
|
||||
for (size_t i = 0; i < weight_full_names_.size(); i++) {
|
||||
total_size += (inputs[i]->size / sizeof(float));
|
||||
}
|
||||
// get pairwise encryption noise vector
|
||||
std::vector<float> noise_vector = GetEncryptNoise(total_size);
|
||||
|
||||
// encrypt original data
|
||||
size_t encrypt_num = 0;
|
||||
for (size_t i = 0; i < weight_full_names_.size(); i++) {
|
||||
const std::string &weight_name = weight_full_names_[i];
|
||||
MS_LOG(INFO) << "Encrypt weights of layer: " << weight_name;
|
||||
size_t weights_size = inputs[i]->size / sizeof(float);
|
||||
float *original_data = reinterpret_cast<float *>(inputs[i]->addr);
|
||||
for (size_t j = 0; j < weights_size; j++) {
|
||||
original_data[j] += noise_vector[j + encrypt_num];
|
||||
}
|
||||
encrypt_num += weights_size;
|
||||
}
|
||||
MS_LOG(INFO) << "Encrypt data finished.";
|
||||
}
|
||||
|
||||
// compute the pairwise noise based on local worker's private key and remote workers' public key
|
||||
std::vector<float> GetEncryptNoise(size_t noise_len) {
|
||||
std::vector<float> total_noise(noise_len, 0);
|
||||
int client_num = client_keys.size();
|
||||
for (int i = 0; i < client_num; i++) {
|
||||
EncryptPublicKeys public_key_set_i = client_keys[i];
|
||||
std::string remote_fl_id = public_key_set_i.flID;
|
||||
// do not need to compute pairwise noise with itself
|
||||
if (remote_fl_id == fl_id_) {
|
||||
continue;
|
||||
}
|
||||
// get local worker's private key
|
||||
armour::PrivateKey *local_private_key = fl::worker::FLWorker::GetInstance().secret_pk();
|
||||
if (local_private_key == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Local secret private key is nullptr, get encryption noise failed!";
|
||||
}
|
||||
|
||||
// choose pw_iv and pw_salt for encryption, we choose that of smaller fl_id worker's
|
||||
std::vector<uint8_t> encrypt_pw_iv;
|
||||
std::vector<uint8_t> encrypt_pw_salt;
|
||||
if (fl_id_ < remote_fl_id) {
|
||||
encrypt_pw_iv = fl::worker::FLWorker::GetInstance().pw_iv();
|
||||
encrypt_pw_salt = fl::worker::FLWorker::GetInstance().pw_salt();
|
||||
} else {
|
||||
encrypt_pw_iv = public_key_set_i.pwIV;
|
||||
encrypt_pw_salt = public_key_set_i.pwSalt;
|
||||
}
|
||||
|
||||
// get keyAgreement seed
|
||||
std::vector<uint8_t> remote_public_key = public_key_set_i.publicKey;
|
||||
armour::PublicKey *pubKey =
|
||||
armour::KeyAgreement::FromPublicBytes(remote_public_key.data(), remote_public_key.size());
|
||||
uint8_t secret1[SECRET_MAX_LEN] = {0};
|
||||
int ret = armour::KeyAgreement::ComputeSharedKey(
|
||||
local_private_key, pubKey, SECRET_MAX_LEN, encrypt_pw_salt.data(), SizeToInt(encrypt_pw_salt.size()), secret1);
|
||||
delete pubKey;
|
||||
if (ret < 0) {
|
||||
MS_LOG(EXCEPTION) << "Get secret seed failed!";
|
||||
}
|
||||
|
||||
// generate pairwise encryption noise
|
||||
std::vector<float> noise_i;
|
||||
if (armour::Masking::GetMasking(&noise_i, noise_len, (const uint8_t *)secret1, SECRET_MAX_LEN,
|
||||
encrypt_pw_iv.data(), encrypt_pw_iv.size()) < 0) {
|
||||
MS_LOG(EXCEPTION) << "Get masking noise failed.";
|
||||
}
|
||||
int noise_sign = (fl_id_ < remote_fl_id) ? -1 : 1;
|
||||
for (size_t k = 0; k < noise_len; k++) {
|
||||
total_noise[k] += noise_sign * noise_i[k];
|
||||
}
|
||||
MS_LOG(INFO) << "Generate noise between fl_id: " << fl_id_ << " and fl_id: " << remote_fl_id << " finished.";
|
||||
}
|
||||
return total_noise;
|
||||
}
|
||||
|
||||
std::shared_ptr<fl::FBBuilder> fbb_;
|
||||
uint32_t rank_id_;
|
||||
uint32_t server_num_;
|
||||
|
@ -165,6 +264,8 @@ class UpdateModelKernel : public CPUKernel {
|
|||
int data_size_;
|
||||
uint64_t iteration_;
|
||||
std::vector<std::string> weight_full_names_;
|
||||
std::string encrypt_mode;
|
||||
std::vector<EncryptPublicKeys> client_keys;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -856,9 +856,10 @@ class UpdateModel(PrimitiveWithInfer):
|
|||
UpdateModel for federated learning worker.
|
||||
"""
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
def __init__(self, encrypt_mode=""):
|
||||
self.add_prim_attr("primitive_target", "CPU")
|
||||
self.add_prim_attr('side_effect_mem', True)
|
||||
self.add_prim_attr('encrypt_mode', encrypt_mode)
|
||||
self.init_prim_io_names(inputs=["weights"], outputs=["result"])
|
||||
|
||||
def infer_shape(self, weights):
|
||||
|
|
Loading…
Reference in New Issue