forked from mindspore-Ecosystem/mindspore
fix bugs of secure aggregation with albert execution failed
This commit is contained in:
parent
43d8687102
commit
1495ef8bf6
|
@ -142,3 +142,6 @@ if(ENABLE_ACL AND NOT ENABLE_D)
|
|||
set(MODE_ASCEND_ACL ON)
|
||||
endif()
|
||||
|
||||
if(ENABLE_CPU AND NOT WIN32)
|
||||
add_compile_definitions(ENABLE_ARMOUR)
|
||||
endif()
|
||||
|
|
|
@ -227,10 +227,6 @@ set(SUB_COMP
|
|||
common debug pybind_api utils vm profiler ps fl
|
||||
)
|
||||
|
||||
if(ENABLE_CPU AND NOT WIN32)
|
||||
add_compile_definitions(ENABLE_ARMOUR)
|
||||
endif()
|
||||
|
||||
foreach(_comp ${SUB_COMP})
|
||||
add_subdirectory(${_comp})
|
||||
string(REPLACE "/" "_" sub ${_comp})
|
||||
|
|
|
@ -262,7 +262,7 @@ bool CipherReconStruct::ReconstructSecrets(const int cur_iterator, const std::st
|
|||
|
||||
bool CipherReconStruct::GetNoiseMasksSum(std::vector<float> *result,
|
||||
const std::map<std::string, std::vector<float>> &client_keys) {
|
||||
float sum[cipher_init_->featuremap_] = {0.0};
|
||||
std::vector<float> sum(cipher_init_->featuremap_, 0.0);
|
||||
for (auto iter = client_keys.begin(); iter != client_keys.end(); iter++) {
|
||||
if (iter->second.size() != cipher_init_->featuremap_) {
|
||||
return false;
|
||||
|
|
|
@ -22,20 +22,6 @@ Random::Random(size_t init_seed) { generator.seed(init_seed); }
|
|||
|
||||
Random::~Random() {}
|
||||
|
||||
void Random::RandUniform(float *array, int size) {
|
||||
std::uniform_real_distribution<double> rand(0, 1);
|
||||
for (int i = 0; i < size; i++) {
|
||||
*(reinterpret_cast<float *>(array) + i) = rand(generator);
|
||||
}
|
||||
}
|
||||
|
||||
void Random::RandNorminal(float *array, int size) {
|
||||
std::normal_distribution<double> randn(0, 1);
|
||||
for (int i = 0; i < size; i++) {
|
||||
*(reinterpret_cast<float *>(array) + i) = randn(generator);
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef _WIN32
|
||||
int Random::GetRandomBytes(unsigned char *secret, int num_bytes) {
|
||||
MS_LOG(ERROR) << "Unsupported feature in Windows platform.";
|
||||
|
@ -49,35 +35,28 @@ int Random::RandomAESCTR(std::vector<float> *noise, int noise_len, const unsigne
|
|||
|
||||
#else
|
||||
int Random::GetRandomBytes(unsigned char *secret, int num_bytes) {
|
||||
int retval = RAND_priv_bytes(secret, RANDOM_LEN);
|
||||
int retval = RAND_priv_bytes(secret, num_bytes);
|
||||
return retval;
|
||||
}
|
||||
|
||||
int Random::RandomAESCTR(std::vector<float> *noise, int noise_len, const unsigned char *seed, int seed_len) {
|
||||
if (seed_len != 16 && seed_len != 32) {
|
||||
std::cout << "seed length must be 16 or 32!" << std::endl;
|
||||
MS_LOG(ERROR) << "seed length must be 16 or 32!";
|
||||
return -1;
|
||||
}
|
||||
int size = noise_len * sizeof(int);
|
||||
unsigned char data[size];
|
||||
unsigned char encrypt_data[size];
|
||||
for (int i = 0; i < size; i++) {
|
||||
data[i] = 0;
|
||||
encrypt_data[i] = 0;
|
||||
}
|
||||
unsigned char ivec[INIT_VEC_SIZE];
|
||||
for (size_t i = 0; i < INIT_VEC_SIZE; i++) {
|
||||
ivec[i] = 0;
|
||||
}
|
||||
int encrypt_len;
|
||||
AESEncrypt encrypt(seed, seed_len, ivec, INIT_VEC_SIZE, AES_CTR);
|
||||
if (encrypt.EncryptData(data, size, encrypt_data, &encrypt_len) != 0) {
|
||||
std::cout << "call encryptData fail!" << std::endl;
|
||||
std::vector<unsigned char> data(size, 0);
|
||||
std::vector<unsigned char> encrypt_data(size, 0);
|
||||
std::vector<unsigned char> ivec(INIT_VEC_SIZE, 0);
|
||||
int encrypt_len = 0;
|
||||
AESEncrypt encrypt(seed, seed_len, ivec.data(), INIT_VEC_SIZE, AES_CTR);
|
||||
if (encrypt.EncryptData(data.data(), size, encrypt_data.data(), &encrypt_len) != 0) {
|
||||
MS_LOG(ERROR) << "call encryptData fail!";
|
||||
return -1;
|
||||
}
|
||||
|
||||
for (int i = 0; i < noise_len; i++) {
|
||||
auto value = *(reinterpret_cast<int32_t *>(encrypt_data) + i);
|
||||
auto value = *(reinterpret_cast<int32_t *>(encrypt_data.data()) + i);
|
||||
noise->emplace_back(static_cast<float>(value) / INT32_MAX);
|
||||
}
|
||||
return 0;
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
#include <openssl/rand.h>
|
||||
#endif
|
||||
#include "fl/armour/secure_protocol/encrypt.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace armour {
|
||||
|
||||
|
@ -35,10 +36,6 @@ class Random {
|
|||
// use openssl RAND_priv_bytes
|
||||
static int GetRandomBytes(unsigned char *secret, int num_bytes);
|
||||
|
||||
// std::uniform_real_distribution<double> rand(0,1)
|
||||
void RandUniform(float *array, int size);
|
||||
// std::normal_distribution<double> randn(0,1);
|
||||
void RandNorminal(float *array, int size);
|
||||
static int RandomAESCTR(std::vector<float> *noise, int noise_len, const unsigned char *seed, int seed_len);
|
||||
|
||||
private:
|
||||
|
|
Loading…
Reference in New Issue