Modify updateModel kernel for STABLE_PW_ENCRYPT.

This commit is contained in:
jin-xiulang 2021-11-26 11:17:17 +08:00
parent 3c3d10fa31
commit c4e84eb207
3 changed files with 107 additions and 1 deletions

View File

@ -147,6 +147,10 @@ std::vector<uint8_t> ExchangeKeysKernel::GetPubicKeyBytes() {
}
// pubLen has been updated, now get public_key bytes
secret_pubkey_ptr = reinterpret_cast<uint8_t *>(malloc(pubLen));
if (secret_pubkey_ptr == nullptr) {
MS_LOG(ERROR) << "secret_pubkey_ptr is nullptr, malloc failed.";
return {};
}
ret = sPriKeyPtr->GetPublicBytes(&pubLen, secret_pubkey_ptr);
if (ret != 0) {
free(secret_pubkey_ptr);

View File

@ -24,9 +24,11 @@
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
#include "fl/worker/fl_worker.h"
#include "fl/armour/secure_protocol/masking.h"
namespace mindspore {
namespace kernel {
constexpr int SECRET_MAX_LEN = 32;
class UpdateModelKernel : public CPUKernel {
public:
UpdateModelKernel() = default;
@ -45,6 +47,10 @@ class UpdateModelKernel : public CPUKernel {
return false;
}
if (encrypt_mode.compare("STABLE_PW_ENCRYPT") == 0) {
EncryptData(inputs);
}
if (!BuildUpdateModelReq(fbb_, inputs)) {
MS_LOG(EXCEPTION) << "Building request for FusedPushWeight failed.";
return false;
@ -92,8 +98,22 @@ class UpdateModelKernel : public CPUKernel {
target_server_rank_ = rank_id_ % server_num_;
fl_name_ = fl::worker::FLWorker::GetInstance().fl_name();
fl_id_ = fl::worker::FLWorker::GetInstance().fl_id();
encrypt_mode = AnfAlgo::GetNodeAttr<string>(kernel_node, "encrypt_mode");
if (encrypt_mode.compare("") != 0 && encrypt_mode.compare("STABLE_PW_ENCRYPT") != 0) {
MS_LOG(EXCEPTION) << "Value Error: the parameter 'encrypt_mode' of updateModel kernel can only be '' or "
"'STABLE_PW_ENCRYPT' until now, but got: "
<< encrypt_mode;
}
MS_LOG(INFO) << "Initializing StartFLJob kernel. fl_name: " << fl_name_ << ", fl_id: " << fl_id_
<< ". Request will be sent to server " << target_server_rank_;
if (encrypt_mode.compare("STABLE_PW_ENCRYPT") == 0) {
MS_LOG(INFO) << "STABLE_PW_ENCRYPT mode is open, model weights will be encrypted before send to server.";
client_keys = fl::worker::FLWorker::GetInstance().public_keys_list();
if (client_keys.size() == 0) {
MS_LOG(EXCEPTION) << "The size of local-stored client_keys_list is 0, please check whether P.ExchangeKeys() "
"and P.GetKeys() have been executed before updateModel.";
}
}
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
for (size_t i = 0; i < input_num; i++) {
@ -156,6 +176,85 @@ class UpdateModelKernel : public CPUKernel {
return true;
}
void EncryptData(const std::vector<AddressPtr> &inputs) {
// calculate the sum of all layer's weight size
size_t total_size = 0;
for (size_t i = 0; i < weight_full_names_.size(); i++) {
total_size += (inputs[i]->size / sizeof(float));
}
// get pairwise encryption noise vector
std::vector<float> noise_vector = GetEncryptNoise(total_size);
// encrypt original data
size_t encrypt_num = 0;
for (size_t i = 0; i < weight_full_names_.size(); i++) {
const std::string &weight_name = weight_full_names_[i];
MS_LOG(INFO) << "Encrypt weights of layer: " << weight_name;
size_t weights_size = inputs[i]->size / sizeof(float);
float *original_data = reinterpret_cast<float *>(inputs[i]->addr);
for (size_t j = 0; j < weights_size; j++) {
original_data[j] += noise_vector[j + encrypt_num];
}
encrypt_num += weights_size;
}
MS_LOG(INFO) << "Encrypt data finished.";
}
// compute the pairwise noise based on local worker's private key and remote workers' public key
std::vector<float> GetEncryptNoise(size_t noise_len) {
std::vector<float> total_noise(noise_len, 0);
int client_num = client_keys.size();
for (int i = 0; i < client_num; i++) {
EncryptPublicKeys public_key_set_i = client_keys[i];
std::string remote_fl_id = public_key_set_i.flID;
// do not need to compute pairwise noise with itself
if (remote_fl_id == fl_id_) {
continue;
}
// get local worker's private key
armour::PrivateKey *local_private_key = fl::worker::FLWorker::GetInstance().secret_pk();
if (local_private_key == nullptr) {
MS_LOG(EXCEPTION) << "Local secret private key is nullptr, get encryption noise failed!";
}
// choose pw_iv and pw_salt for encryption, we choose that of smaller fl_id worker's
std::vector<uint8_t> encrypt_pw_iv;
std::vector<uint8_t> encrypt_pw_salt;
if (fl_id_ < remote_fl_id) {
encrypt_pw_iv = fl::worker::FLWorker::GetInstance().pw_iv();
encrypt_pw_salt = fl::worker::FLWorker::GetInstance().pw_salt();
} else {
encrypt_pw_iv = public_key_set_i.pwIV;
encrypt_pw_salt = public_key_set_i.pwSalt;
}
// get keyAgreement seed
std::vector<uint8_t> remote_public_key = public_key_set_i.publicKey;
armour::PublicKey *pubKey =
armour::KeyAgreement::FromPublicBytes(remote_public_key.data(), remote_public_key.size());
uint8_t secret1[SECRET_MAX_LEN] = {0};
int ret = armour::KeyAgreement::ComputeSharedKey(
local_private_key, pubKey, SECRET_MAX_LEN, encrypt_pw_salt.data(), SizeToInt(encrypt_pw_salt.size()), secret1);
delete pubKey;
if (ret < 0) {
MS_LOG(EXCEPTION) << "Get secret seed failed!";
}
// generate pairwise encryption noise
std::vector<float> noise_i;
if (armour::Masking::GetMasking(&noise_i, noise_len, (const uint8_t *)secret1, SECRET_MAX_LEN,
encrypt_pw_iv.data(), encrypt_pw_iv.size()) < 0) {
MS_LOG(EXCEPTION) << "Get masking noise failed.";
}
int noise_sign = (fl_id_ < remote_fl_id) ? -1 : 1;
for (size_t k = 0; k < noise_len; k++) {
total_noise[k] += noise_sign * noise_i[k];
}
MS_LOG(INFO) << "Generate noise between fl_id: " << fl_id_ << " and fl_id: " << remote_fl_id << " finished.";
}
return total_noise;
}
std::shared_ptr<fl::FBBuilder> fbb_;
uint32_t rank_id_;
uint32_t server_num_;
@ -165,6 +264,8 @@ class UpdateModelKernel : public CPUKernel {
int data_size_;
uint64_t iteration_;
std::vector<std::string> weight_full_names_;
std::string encrypt_mode;
std::vector<EncryptPublicKeys> client_keys;
};
} // namespace kernel
} // namespace mindspore

View File

@ -856,9 +856,10 @@ class UpdateModel(PrimitiveWithInfer):
UpdateModel for federated learning worker.
"""
@prim_attr_register
def __init__(self):
def __init__(self, encrypt_mode=""):
self.add_prim_attr("primitive_target", "CPU")
self.add_prim_attr('side_effect_mem', True)
self.add_prim_attr('encrypt_mode', encrypt_mode)
self.init_prim_io_names(inputs=["weights"], outputs=["result"])
def infer_shape(self, weights):