remove gmp dependency of fl secure aggregation

This commit is contained in:
jin-xiulang 2021-09-02 10:56:32 +08:00
parent 5a851daf2f
commit b116e9cf85
7 changed files with 217 additions and 200 deletions

View File

@ -380,11 +380,8 @@ if(USE_GLOG)
target_link_libraries(_c_expression PRIVATE mindspore::glog)
endif()
find_library(gmp_LIB gmp)
find_library(gmpxx_LIB gmpxx)
if(${CMAKE_SYSTEM_NAME} MATCHES "Linux")
target_link_libraries(mindspore ${gmp_LIB} ${gmpxx_LIB} mindspore::crypto mindspore::ssl)
target_link_libraries(mindspore mindspore::crypto mindspore::ssl)
endif()
if(ENABLE_GPU)
@ -397,8 +394,7 @@ if(ENABLE_GPU)
${CUDA_PATH}/lib64/libcusolver.so
${CUDA_PATH}/lib64/libcufft.so)
if(${CMAKE_SYSTEM_NAME} MATCHES "Linux")
target_link_libraries(_c_expression PRIVATE ${gmp_LIB} ${gmpxx_LIB}
mindspore::crypto mindspore::ssl)
target_link_libraries(_c_expression PRIVATE mindspore::crypto mindspore::ssl)
endif()
if(ENABLE_MPI)
set_target_properties(_ms_mpi PROPERTIES INSTALL_RPATH ${MINDSPORE_RPATH})

View File

@ -23,9 +23,6 @@
#include <vector>
#include <string>
#include <memory>
#ifndef _WIN32
#include <gmp.h>
#endif
#include "proto/ps.pb.h"
#include "utils/log_adapter.h"
#include "fl/armour/secure_protocol/secret_sharing.h"

View File

@ -45,10 +45,12 @@ bool CipherReconStruct::CombineMask(
}
}
MS_LOG(INFO) << "fl_id_src : " << fl_id;
mpz_t prime;
mpz_init(prime);
BIGNUM *prime = BN_new();
if (prime == nullptr) {
return false;
}
auto publicparam_ = CipherInit::GetInstance().GetPublicParams();
mpz_import(prime, PRIME_MAX_LEN, 1, 1, 0, 0, publicparam_->prime);
(void)BN_bin2bn(publicparam_->prime, PRIME_MAX_LEN, prime);
if (iter->second.size() >= cipher_init_->secrets_minnums_) { // combine private key seed.
MS_LOG(INFO) << "start assign secrets shares to public shares ";
for (int i = 0; i < static_cast<int>(cipher_init_->secrets_minnums_); ++i) {
@ -65,10 +67,9 @@ bool CipherReconStruct::CombineMask(
MS_LOG(INFO) << "end assign secrets shares to public shares ";
size_t length;
char secret[SECRET_MAX_LEN] = {0};
uint8_t secret[SECRET_MAX_LEN] = {0};
SecretSharing combine(prime);
if (combine.Combine(static_cast<int>(cipher_init_->secrets_minnums_), *shares_tmp, secret, &length) < 0)
retcode = false;
if (combine.Combine(cipher_init_->secrets_minnums_, *shares_tmp, secret, &length) < 0) retcode = false;
length = SECRET_MAX_LEN;
MS_LOG(INFO) << "combine secrets shares Success.";
@ -303,7 +304,7 @@ void CipherReconStruct::BuildReconstructSecretsRsp(const std::shared_ptr<fl::ser
bool CipherReconStruct::GetSuvNoise(
const std::vector<std::string> &clients_share_list,
const std::map<std::string, std::vector<std::vector<unsigned char>>> &record_public_keys, const string &fl_id,
std::vector<float> *noise, char *secret, int length) {
std::vector<float> *noise, uint8_t *secret, int length) {
for (auto p_key = clients_share_list.begin(); p_key != clients_share_list.end(); ++p_key) {
if (*p_key != fl_id) {
PrivateKey *privKey1 = KeyAgreement::FromPrivateBytes((unsigned char *)secret, length);

View File

@ -61,7 +61,7 @@ class CipherReconStruct {
// get suv noise by computing shares result.
bool GetSuvNoise(const std::vector<std::string> &clients_share_list,
const std::map<std::string, std::vector<std::vector<unsigned char>>> &record_public_keys,
const string &fl_id, std::vector<float> *noise, char *secret, int length);
const string &fl_id, std::vector<float> *noise, uint8_t *secret, int length);
// malloc shares.
bool MallocShares(std::vector<Share *> *shares_tmp, int shares_size);
// delete shares.

View File

@ -25,201 +25,219 @@ void secure_zero(unsigned char *s, size_t n) {
}
#ifndef _WIN32
int GetRandInteger(mpz_t x, mpz_t prim) {
size_t bytes_len = (mpz_sizeinbase(prim, 2) + 8 - 1) / 8;
unsigned char buf[bytes_len];
while (true) {
if (!RAND_bytes(buf, bytes_len)) {
MS_LOG(WARNING) << "Get Rand Integer failed!";
continue;
}
mpz_import(x, bytes_len, 1, 1, 0, 0, buf);
secure_zero(buf, sizeof(buf));
if (mpz_cmp_ui(x, 0) > 0 && mpz_cmp(x, prim) < 0) {
return 0;
}
}
}
int GetRandomPrime(mpz_t prim) {
mpz_t rand;
mpz_init(rand);
int GetPrime(BIGNUM *prim) {
constexpr int byteBits = 8;
const int max_prime_len = SECRET_MAX_LEN + 1;
unsigned char buf[max_prime_len];
if (!RAND_bytes(buf, max_prime_len)) {
MS_LOG(ERROR) << "Get Rand Integer failed!";
const int maxCount = 500;
int count = 1;
int ret = 0;
while (count < maxCount) {
ret = BN_generate_prime_ex(prim, max_prime_len * byteBits, 1, NULL, NULL, NULL);
if (ret == 1) {
break;
}
count++;
}
if (ret != 1 || BN_num_bytes(prim) != max_prime_len) {
MS_LOG(ERROR) << "Get prim failed, get count: " << count;
MS_LOG(ERROR) << "BN_num_bytes: " << BN_num_bytes(prim) << ", max_prime_len: " << max_prime_len;
return -1;
}
mpz_import(rand, max_prime_len, 1, 1, 0, 0, buf);
mpz_nextprime(prim, rand);
mpz_clear(rand);
secure_zero(buf, sizeof(buf));
MS_LOG(INFO) << "Get prim success, get count: " << count;
return 0;
}
void PrintBigInteger(mpz_t x) {
char *tmp = mpz_get_str(NULL, 16, x);
std::string Str = tmp;
MS_LOG(INFO) << Str;
void (*freefunc)(void *, size_t);
mp_get_memory_functions(NULL, NULL, &freefunc);
freefunc(tmp, strlen(tmp) + 1);
}
void PrintBigInteger(mpz_t x, int hex) {
char *tmp = mpz_get_str(NULL, hex, x);
std::string Str = tmp;
MS_LOG(INFO) << Str;
void (*freefunc)(void *, size_t);
mp_get_memory_functions(NULL, NULL, &freefunc);
freefunc(tmp, strlen(tmp) + 1);
}
Share::~Share() {
if (this->data != nullptr) free(this->data);
}
SecretSharing::SecretSharing(mpz_t prim) {
mpz_init(this->prim_);
mpz_set(this->prim_, prim);
}
SecretSharing::~SecretSharing() { mpz_clear(this->prim_); }
void SecretSharing::GetPolyVal(int k, mpz_t y, const mpz_t x, const mpz_t coeff[]) {
int i;
mpz_set_ui(y, 0);
for (i = k - 1; i >= 0; i--) {
field_mult(y, y, x);
field_add(y, y, coeff[i]);
SecretSharing::SecretSharing(BIGNUM *prim) {
if (prim != nullptr) {
this->bn_prim_ = BN_dup(prim);
} else {
this->bn_prim_ = nullptr;
}
}
void SecretSharing::field_invert(mpz_t z, const mpz_t x) { mpz_invert(z, x, this->prim_); }
void SecretSharing::field_add(mpz_t z, const mpz_t x, const mpz_t y) {
mpz_add(z, x, y);
mpz_mod(z, z, this->prim_);
SecretSharing::~SecretSharing() {
if (this->bn_prim_ != nullptr) {
BN_clear_free(this->bn_prim_);
}
}
void SecretSharing::field_mult(mpz_t z, const mpz_t x, const mpz_t y) {
mpz_mul(z, x, y);
mpz_mod(z, z, this->prim_);
bool SecretSharing::field_mult(BIGNUM *z, const BIGNUM *x, const BIGNUM *y, BN_CTX *ctx) {
if (BN_mul(z, x, y, ctx) != 1) {
return false;
}
if (BN_mod(z, z, this->bn_prim_, ctx) != 1) {
return false;
}
return true;
}
int SecretSharing::CalculateShares(const mpz_t coeff[], int k, int n, const std::vector<Share *> &shares) {
mpz_t x, y;
mpz_init(x);
mpz_init(y);
for (int i = 0; i < n; i++) {
mpz_set_ui(x, i + 1);
GetPolyVal(k, y, x, coeff);
shares[i]->index = i + 1;
size_t share_len = (mpz_sizeinbase(y, 2) + 8 - 1) / 8;
shares[i]->data = (unsigned char *)malloc(share_len + 1);
mpz_export(shares[i]->data, &(shares[i]->len), 1, 1, 0, 0, y);
if (shares[i]->len != share_len) {
MS_LOG(ERROR) << "share_len is not equal";
bool SecretSharing::field_add(BIGNUM *z, const BIGNUM *x, const BIGNUM *y, BN_CTX *ctx) {
if (BN_add(z, x, y) != 1) {
return false;
}
if (BN_mod(z, z, this->bn_prim_, ctx) != 1) {
return false;
}
return true;
}
bool SecretSharing::field_sub(BIGNUM *z, const BIGNUM *x, const BIGNUM *y, BN_CTX *ctx) {
if (BN_sub(z, x, y) != 1) {
return false;
}
if (BN_mod(z, z, this->bn_prim_, ctx) != 1) {
return false;
}
return true;
}
bool SecretSharing::GetShare(BIGNUM *x, BIGNUM *share, Share *s_share) {
if (x == nullptr || share == nullptr || s_share == nullptr) {
return false;
}
if (BN_set_word(x, s_share->index) != 1) {
return false;
}
(void)BN_bin2bn(s_share->data, SizeToInt(s_share->len), share);
return true;
}
void SecretSharing::FreeBNVector(std::vector<BIGNUM *> bns) {
for (size_t i = 0; i < bns.size(); i++) {
if (bns[i] != nullptr) {
BN_clear_free(bns[i]);
}
}
}
int SecretSharing::CheckShares(Share *share_i, BIGNUM *x_i, BIGNUM *y_i, BIGNUM *denses_i, BIGNUM *nums_i) {
if (x_i == nullptr || y_i == nullptr || denses_i == nullptr || nums_i == nullptr) {
MS_LOG(ERROR) << "new bn object failed";
return -1;
} else {
if (!GetShare(x_i, y_i, share_i)) {
MS_LOG(ERROR) << "get share failed";
return -1;
}
MS_LOG(INFO) << "share_" << i + 1 << ": ";
PrintBigInteger(y);
}
mpz_clear(x);
mpz_clear(y);
return 0;
}
int SecretSharing::Split(int n, const int k, const char *secret, const size_t length,
const std::vector<Share *> &shares) {
if (k <= 1 || k > n) {
MS_LOG(ERROR) << "invalid parameters";
return -1;
}
if (static_cast<int>(shares.size()) != n) {
MS_LOG(ERROR) << "the size of shares must be equal to n";
return -1;
}
this->degree_ = length * 8;
const int kCoeffLen = k;
mpz_t coeff[kCoeffLen];
int SecretSharing::CheckSum(BIGNUM *sum) {
int ret = 0;
int i = 0;
mpz_init(coeff[i]);
mpz_import(coeff[i], length, 1, 1, 0, 0, secret);
i++;
for (; i < k && ret == 0; i++) {
mpz_init(coeff[i]);
ret = GetRandInteger(coeff[i], this->prim_);
if (ret != 0) {
break;
if (sum == nullptr) {
MS_LOG(ERROR) << "new bn object failed";
ret = -1;
} else {
if (BN_zero(sum) != 1) {
ret = -1;
}
MS_LOG(INFO) << "coeff_" << i << ":";
PrintBigInteger(coeff[i]);
}
if (ret == 0) {
ret = CalculateShares(coeff, k, n, shares);
}
for (i = 0; i < k; i++) mpz_clear(coeff[i]);
return ret;
}
void SecretSharing::GetShare(mpz_t x, mpz_t share, Share *s_share) {
mpz_set_ui(x, s_share->index);
mpz_import(share, s_share->len, 1, 1, 0, 0, s_share->data);
int SecretSharing::LagrangeCal(BIGNUM *nums_j, BIGNUM *x_m, BIGNUM *x_j, BIGNUM *denses_j, BIGNUM *tmp, BN_CTX *ctx) {
if (!field_mult(nums_j, nums_j, x_m, ctx)) {
return -1;
}
if (!field_sub(tmp, x_m, x_j, ctx)) {
return -1;
}
if (!field_mult(denses_j, denses_j, tmp, ctx)) {
return -1;
}
return 0;
}
int SecretSharing::Combine(int k, const std::vector<Share *> &shares, char *secret, size_t *length) {
int SecretSharing::InputCheck(size_t k, const std::vector<Share *> &shares, uint8_t *secret, size_t *length) {
if (secret == nullptr || length == nullptr || k < 1 || shares.size() < k || this->bn_prim_ == nullptr) {
return -1;
}
return 0;
}
void SecretSharing::ReleaseNum(BIGNUM *bigNum) {
if (bigNum != nullptr) {
BN_clear_free(bigNum);
}
}
int SecretSharing::Combine(size_t k, const std::vector<Share *> &shares, uint8_t *secret, size_t *length) {
int check_result = InputCheck(k, shares, secret, length);
if (check_result == -1) return -1;
BN_CTX *ctx = BN_CTX_new();
if (ctx == nullptr) {
MS_LOG(ERROR) << "new bn ctx failed";
return -1;
}
int ret = 0;
mpz_t y[k], x[k], denses[k], nums[k];
int i, j, m;
std::vector<BIGNUM *> y(k);
std::vector<BIGNUM *> x(k);
std::vector<BIGNUM *> denses(k);
std::vector<BIGNUM *> nums(k);
BIGNUM *sum = nullptr;
for (i = 0; i < k; i++) {
mpz_init(x[i]);
mpz_init(y[i]);
mpz_init(denses[i]);
mpz_init(nums[i]);
GetShare(x[i], y[i], shares[i]);
MS_LOG(INFO) << "combine -- share_" << mpz_get_str(NULL, 10, x[i]) << ": ";
PrintBigInteger(y[i]);
MS_LOG(INFO) << "index is : " << shares[i]->index;
MS_LOG(INFO) << "len is %zu " << shares[i]->len;
for (size_t i = 0; i < k; i++) {
x[i] = BN_new();
y[i] = BN_new();
denses[i] = BN_new();
nums[i] = BN_new();
ret = CheckShares(shares[i], x[i], y[i], denses[i], nums[i]);
if (ret == -1) break;
}
mpz_t sum;
mpz_init(sum);
mpz_set_ui(sum, 0);
for (j = 0; j < k; j++) {
mpz_set_ui(denses[j], 1);
mpz_set_ui(nums[j], 1);
mpz_t tmp;
mpz_init(tmp);
for (m = 0; m < k; m++) {
if (m != j) {
field_mult(nums[j], nums[j], x[m]);
mpz_mul_si(tmp, x[j], -1);
field_add(tmp, x[m], tmp);
field_mult(denses[j], denses[j], tmp);
if (ret != -1) {
sum = BN_new();
ret = CheckSum(sum);
}
if (ret != -1) {
for (size_t j = 0; j < k; j++) {
if (BN_one(denses[j]) != 1 || BN_one(nums[j]) != 1) {
ret = -1;
break;
}
BIGNUM *tmp = BN_new();
if (tmp == nullptr) {
MS_LOG(ERROR) << "new bn object failed";
ret = -1;
break;
}
for (size_t m = 0; m < k; m++) {
if (m != j) {
ret = LagrangeCal(nums[j], x[m], x[j], denses[j], tmp, ctx);
if (ret == -1) break;
}
}
(void)BN_mod_inverse(tmp, denses[j], this->bn_prim_, ctx);
if (!field_mult(tmp, tmp, nums[j], ctx)) {
ret = -1;
break;
}
if (!field_mult(tmp, tmp, y[j], ctx)) {
ret = -1;
break;
}
if (!field_add(sum, sum, tmp, ctx)) {
ret = -1;
break;
}
BN_clear_free(tmp);
}
field_invert(tmp, denses[j]);
field_mult(tmp, tmp, nums[j]);
field_mult(tmp, tmp, y[j]);
field_add(sum, sum, tmp);
mpz_clear(tmp);
}
mpz_export(secret, length, 1, 1, 0, 0, sum);
PrintBigInteger(sum);
mpz_clear(sum);
for (i = 0; i < k; i++) {
mpz_clear(x[i]);
mpz_clear(y[i]);
mpz_clear(nums[i]);
mpz_clear(denses[i]);
*length = BN_bn2bin(sum, secret);
}
BN_CTX_free(ctx);
ReleaseNum(sum);
FreeBNVector(x);
FreeBNVector(y);
FreeBNVector(denses);
FreeBNVector(nums);
return ret;
}
#endif

View File

@ -17,12 +17,12 @@
#ifndef MINDSPORE_SECRET_SHARING_H
#define MINDSPORE_SECRET_SHARING_H
#ifndef _WIN32
#include <gmp.h>
#include "openssl/rand.h"
#include "openssl/bn.h"
#endif
#include <string>
#include <vector>
#include "utils/log_adapter.h"
#include "fl/server/common.h"
namespace mindspore {
namespace armour {
@ -37,36 +37,35 @@ struct Share {
};
#ifndef _WIN32
void secure_zero(void *s, size_t);
int GetRandInteger(mpz_t x, mpz_t prim);
int GetRandomPrime(mpz_t prim);
void PrintBigInteger(mpz_t x);
void PrintBigInteger(mpz_t x, int hex);
void secure_zero(uint8_t *s, size_t);
int GetPrime(BIGNUM *prim);
class SecretSharing {
public:
explicit SecretSharing(mpz_t prim);
explicit SecretSharing(BIGNUM *prim);
~SecretSharing();
// split the input secret into multiple shares
int Split(int n, const int k, const char *secret, size_t length, const std::vector<Share *> &shares);
// reconstruct the secret from multiple shares
int Combine(int k, const std::vector<Share *> &shares, char *secret, size_t *length);
int Combine(size_t k, const std::vector<Share *> &shares, uint8_t *secret, size_t *length);
int CheckShares(Share *share_i, BIGNUM *x_i, BIGNUM *y_i, BIGNUM *denses_i, BIGNUM *nums_i);
int CheckSum(BIGNUM *sum);
int LagrangeCal(BIGNUM *nums_j, BIGNUM *x_m, BIGNUM *x_j, BIGNUM *denses_j, BIGNUM *tmp, BN_CTX *ctx);
int InputCheck(size_t k, const std::vector<Share *> &shares, uint8_t *secret, size_t *length);
void ReleaseNum(BIGNUM *bigNum);
private:
mpz_t prim_;
BIGNUM *bn_prim_;
size_t degree_;
// calculate shares from a polynomial
int CalculateShares(const mpz_t coeff[], int k, int n, const std::vector<Share *> &shares);
// inversion in finite field
void field_invert(mpz_t z, const mpz_t x);
// addition in finite field
void field_add(mpz_t z, const mpz_t x, const mpz_t y);
bool field_add(BIGNUM *z, const BIGNUM *x, const BIGNUM *y, BN_CTX *ctx);
// multiplication in finite field
void field_mult(mpz_t z, const mpz_t x, const mpz_t y);
// evaluate polynomial at x
void GetPolyVal(int k, mpz_t y, const mpz_t x, const mpz_t coeff[]);
// convert secret sharing from Share type to mpz_t type
void GetShare(mpz_t x, mpz_t share, Share *s_share);
bool field_mult(BIGNUM *z, const BIGNUM *x, const BIGNUM *y, BN_CTX *ctx);
// subtraction in finite field
bool field_sub(BIGNUM *z, const BIGNUM *x, const BIGNUM *y, BN_CTX *ctx);
// convert secret sharing from Share type to BIGNUM type
bool GetShare(BIGNUM *x, BIGNUM *share, Share *s_share);
void FreeBNVector(std::vector<BIGNUM *> bns);
};
#endif

View File

@ -284,13 +284,19 @@ void Server::InitCipher() {
float dp_norm_clip = ps::PSContext::instance()->dp_norm_clip();
std::string encrypt_type = ps::PSContext::instance()->encrypt_type();
mpz_t prim;
mpz_init(prim);
mindspore::armour::GetRandomPrime(prim);
mindspore::armour::PrintBigInteger(prim, 16);
BIGNUM *prim = BN_new();
if (prim == nullptr) {
MS_LOG(EXCEPTION) << "new bn failed";
}
mindspore::armour::GetPrime(prim);
MS_LOG(INFO) << "prime" << BN_bn2hex(prim);
(void)BN_bn2bin(prim, reinterpret_cast<uint8_t *>(cipher_prime));
if (prim != nullptr) {
BN_clear_free(prim);
}
size_t len_cipher_prime;
mpz_export((unsigned char *)cipher_prime, &len_cipher_prime, sizeof(unsigned char), 1, 0, 0, prim);
mindspore::armour::CipherPublicPara param;
param.g = cipher_g;
param.t = cipher_t;