From 1495ef8bf63fc9bb1dba3fb101203495c870e3b1 Mon Sep 17 00:00:00 2001 From: yangyuan Date: Wed, 14 Jul 2021 15:12:27 +0800 Subject: [PATCH] fix bugs of secure aggregation with albert execution failed --- cmake/options.cmake | 3 ++ mindspore/ccsrc/CMakeLists.txt | 4 -- .../fl/armour/cipher/cipher_reconstruct.cc | 2 +- .../ccsrc/fl/armour/secure_protocol/random.cc | 41 +++++-------------- .../ccsrc/fl/armour/secure_protocol/random.h | 5 +-- 5 files changed, 15 insertions(+), 40 deletions(-) diff --git a/cmake/options.cmake b/cmake/options.cmake index 120ce114cbd..2a201475812 100644 --- a/cmake/options.cmake +++ b/cmake/options.cmake @@ -142,3 +142,6 @@ if(ENABLE_ACL AND NOT ENABLE_D) set(MODE_ASCEND_ACL ON) endif() +if(ENABLE_CPU AND NOT WIN32) + add_compile_definitions(ENABLE_ARMOUR) +endif() diff --git a/mindspore/ccsrc/CMakeLists.txt b/mindspore/ccsrc/CMakeLists.txt index 2e1be8c3339..76a50095e99 100644 --- a/mindspore/ccsrc/CMakeLists.txt +++ b/mindspore/ccsrc/CMakeLists.txt @@ -227,10 +227,6 @@ set(SUB_COMP common debug pybind_api utils vm profiler ps fl ) -if(ENABLE_CPU AND NOT WIN32) - add_compile_definitions(ENABLE_ARMOUR) -endif() - foreach(_comp ${SUB_COMP}) add_subdirectory(${_comp}) string(REPLACE "/" "_" sub ${_comp}) diff --git a/mindspore/ccsrc/fl/armour/cipher/cipher_reconstruct.cc b/mindspore/ccsrc/fl/armour/cipher/cipher_reconstruct.cc index 52daf7e421b..51a9406f2fe 100644 --- a/mindspore/ccsrc/fl/armour/cipher/cipher_reconstruct.cc +++ b/mindspore/ccsrc/fl/armour/cipher/cipher_reconstruct.cc @@ -262,7 +262,7 @@ bool CipherReconStruct::ReconstructSecrets(const int cur_iterator, const std::st bool CipherReconStruct::GetNoiseMasksSum(std::vector *result, const std::map> &client_keys) { - float sum[cipher_init_->featuremap_] = {0.0}; + std::vector sum(cipher_init_->featuremap_, 0.0); for (auto iter = client_keys.begin(); iter != client_keys.end(); iter++) { if (iter->second.size() != cipher_init_->featuremap_) { return false; diff --git a/mindspore/ccsrc/fl/armour/secure_protocol/random.cc b/mindspore/ccsrc/fl/armour/secure_protocol/random.cc index e43152c471f..898955500ea 100644 --- a/mindspore/ccsrc/fl/armour/secure_protocol/random.cc +++ b/mindspore/ccsrc/fl/armour/secure_protocol/random.cc @@ -22,20 +22,6 @@ Random::Random(size_t init_seed) { generator.seed(init_seed); } Random::~Random() {} -void Random::RandUniform(float *array, int size) { - std::uniform_real_distribution rand(0, 1); - for (int i = 0; i < size; i++) { - *(reinterpret_cast(array) + i) = rand(generator); - } -} - -void Random::RandNorminal(float *array, int size) { - std::normal_distribution randn(0, 1); - for (int i = 0; i < size; i++) { - *(reinterpret_cast(array) + i) = randn(generator); - } -} - #ifdef _WIN32 int Random::GetRandomBytes(unsigned char *secret, int num_bytes) { MS_LOG(ERROR) << "Unsupported feature in Windows platform."; @@ -49,35 +35,28 @@ int Random::RandomAESCTR(std::vector *noise, int noise_len, const unsigne #else int Random::GetRandomBytes(unsigned char *secret, int num_bytes) { - int retval = RAND_priv_bytes(secret, RANDOM_LEN); + int retval = RAND_priv_bytes(secret, num_bytes); return retval; } int Random::RandomAESCTR(std::vector *noise, int noise_len, const unsigned char *seed, int seed_len) { if (seed_len != 16 && seed_len != 32) { - std::cout << "seed length must be 16 or 32!" << std::endl; + MS_LOG(ERROR) << "seed length must be 16 or 32!"; return -1; } int size = noise_len * sizeof(int); - unsigned char data[size]; - unsigned char encrypt_data[size]; - for (int i = 0; i < size; i++) { - data[i] = 0; - encrypt_data[i] = 0; - } - unsigned char ivec[INIT_VEC_SIZE]; - for (size_t i = 0; i < INIT_VEC_SIZE; i++) { - ivec[i] = 0; - } - int encrypt_len; - AESEncrypt encrypt(seed, seed_len, ivec, INIT_VEC_SIZE, AES_CTR); - if (encrypt.EncryptData(data, size, encrypt_data, &encrypt_len) != 0) { - std::cout << "call encryptData fail!" << std::endl; + std::vector data(size, 0); + std::vector encrypt_data(size, 0); + std::vector ivec(INIT_VEC_SIZE, 0); + int encrypt_len = 0; + AESEncrypt encrypt(seed, seed_len, ivec.data(), INIT_VEC_SIZE, AES_CTR); + if (encrypt.EncryptData(data.data(), size, encrypt_data.data(), &encrypt_len) != 0) { + MS_LOG(ERROR) << "call encryptData fail!"; return -1; } for (int i = 0; i < noise_len; i++) { - auto value = *(reinterpret_cast(encrypt_data) + i); + auto value = *(reinterpret_cast(encrypt_data.data()) + i); noise->emplace_back(static_cast(value) / INT32_MAX); } return 0; diff --git a/mindspore/ccsrc/fl/armour/secure_protocol/random.h b/mindspore/ccsrc/fl/armour/secure_protocol/random.h index 7b435c6d74c..efabcc2b4e0 100644 --- a/mindspore/ccsrc/fl/armour/secure_protocol/random.h +++ b/mindspore/ccsrc/fl/armour/secure_protocol/random.h @@ -23,6 +23,7 @@ #include #endif #include "fl/armour/secure_protocol/encrypt.h" + namespace mindspore { namespace armour { @@ -35,10 +36,6 @@ class Random { // use openssl RAND_priv_bytes static int GetRandomBytes(unsigned char *secret, int num_bytes); - // std::uniform_real_distribution rand(0,1) - void RandUniform(float *array, int size); - // std::normal_distribution randn(0,1); - void RandNorminal(float *array, int size); static int RandomAESCTR(std::vector *noise, int noise_len, const unsigned char *seed, int seed_len); private: