forked from mindspore-Ecosystem/mindspore
!19008 disable code invoking openssl APIs in Windows platform
Merge pull request !19008 from yyuse/fl_mpc_ci
This commit is contained in:
commit
fc884e44b6
|
@ -384,6 +384,13 @@ if(USE_GLOG)
|
|||
target_link_libraries(_c_expression PRIVATE mindspore::glog)
|
||||
endif()
|
||||
|
||||
find_library(gmp_LIB gmp)
|
||||
find_library(gmpxx_LIB gmpxx)
|
||||
|
||||
if(${CMAKE_SYSTEM_NAME} MATCHES "Linux")
|
||||
target_link_libraries(mindspore ${gmp_LIB} ${gmpxx_LIB} mindspore::crypto mindspore::ssl)
|
||||
endif()
|
||||
|
||||
if(ENABLE_GPU)
|
||||
message("add gpu lib to c_expression")
|
||||
target_link_libraries(_c_expression PRIVATE gpu_cuda_lib gpu_queue cublas
|
||||
|
@ -393,6 +400,10 @@ if(ENABLE_GPU)
|
|||
${CUDA_PATH}/lib64/stubs/libcuda.so
|
||||
${CUDA_PATH}/lib64/libcusolver.so
|
||||
${CUDA_PATH}/lib64/libcufft.so)
|
||||
if(${CMAKE_SYSTEM_NAME} MATCHES "Linux")
|
||||
target_link_libraries(_c_expression PRIVATE ${gmp_LIB} ${gmpxx_LIB}
|
||||
mindspore::crypto mindspore::ssl)
|
||||
endif()
|
||||
if(ENABLE_MPI)
|
||||
set_target_properties(_ms_mpi PROPERTIES INSTALL_RPATH ${MINDSPORE_RPATH})
|
||||
endif()
|
||||
|
|
|
@ -1,5 +1,14 @@
|
|||
file(GLOB_RECURSE ARMOUR_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
|
||||
|
||||
if(NOT ENABLE_CPU OR WIN32)
|
||||
list(REMOVE_ITEM ARMOUR_FILES "cipher/cipher_init.cc")
|
||||
list(REMOVE_ITEM ARMOUR_FILES "cipher/cipher_keys.cc")
|
||||
list(REMOVE_ITEM ARMOUR_FILES "cipher/cipher_meta_storage.cc")
|
||||
list(REMOVE_ITEM ARMOUR_FILES "cipher/cipher_reconstruct.cc")
|
||||
list(REMOVE_ITEM ARMOUR_FILES "cipher/cipher_shares.cc")
|
||||
list(REMOVE_ITEM ARMOUR_FILES "cipher/cipher_unmask.cc")
|
||||
endif()
|
||||
|
||||
set(SERVER_FLATBUFFER_OUTPUT "${CMAKE_BINARY_DIR}/schema")
|
||||
set(FBS_FILES
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../schema/cipher.fbs
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
*/
|
||||
|
||||
#include "armour/secure_protocol/encrypt.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace armour {
|
||||
|
||||
|
@ -32,17 +33,32 @@ AESEncrypt::AESEncrypt(const unsigned char *key, int key_len, unsigned char *ive
|
|||
|
||||
AESEncrypt::~AESEncrypt() {}
|
||||
|
||||
#if defined(_WIN32)
|
||||
int AESEncrypt::EncryptData(const unsigned char *data, const int len, unsigned char *encrypt_data, int *encrypt_len) {
|
||||
MS_LOG(ERROR) << "Unsupported feature in Windows platform.";
|
||||
return -1;
|
||||
}
|
||||
|
||||
int AESEncrypt::DecryptData(const unsigned char *encrypt_data, const int encrypt_len, unsigned char *data, int *len) {
|
||||
MS_LOG(ERROR) << "Unsupported feature in Windows platform.";
|
||||
return -1;
|
||||
}
|
||||
|
||||
#else
|
||||
int AESEncrypt::EncryptData(const unsigned char *data, const int len, unsigned char *encrypt_data, int *encrypt_len) {
|
||||
int ret;
|
||||
if (privKeyLen != KEY_STEP_MIN && privKeyLen != KEY_STEP_MAX) {
|
||||
std::cout << "key length must be 16 or 32!" << std::endl;
|
||||
MS_LOG(ERROR) << "key length must be 16 or 32!";
|
||||
return -1;
|
||||
}
|
||||
if (iVecLen != INIT_VEC_SIZE) {
|
||||
MS_LOG(ERROR) << "initial vector size must be 16!";
|
||||
return -1;
|
||||
}
|
||||
assert(iVecLen == INIT_VEC_SIZE);
|
||||
if (aesMode == AES_CBC || aesMode == AES_CTR) {
|
||||
ret = evp_aes_encrypt(data, len, privKey, iVec, encrypt_data, encrypt_len);
|
||||
} else {
|
||||
std::cout << "Please use CBC mode or CTR mode, the other modes are not supported!\n" << std::endl;
|
||||
MS_LOG(ERROR) << "Please use CBC mode or CTR mode, the other modes are not supported!";
|
||||
ret = -1;
|
||||
}
|
||||
if (ret != 0) {
|
||||
|
@ -54,14 +70,17 @@ int AESEncrypt::EncryptData(const unsigned char *data, const int len, unsigned c
|
|||
int AESEncrypt::DecryptData(const unsigned char *encrypt_data, const int encrypt_len, unsigned char *data, int *len) {
|
||||
int ret = 0;
|
||||
if (privKeyLen != KEY_STEP_MIN && privKeyLen != KEY_STEP_MAX) {
|
||||
std::cout << "key length must be 16 or 32!" << std::endl;
|
||||
MS_LOG(ERROR) << "key length must be 16 or 32!";
|
||||
return -1;
|
||||
}
|
||||
if (iVecLen != INIT_VEC_SIZE) {
|
||||
MS_LOG(ERROR) << "initial vector size must be 16!";
|
||||
return -1;
|
||||
}
|
||||
assert(iVecLen == INIT_VEC_SIZE);
|
||||
if (aesMode == AES_CBC || aesMode == AES_CTR) {
|
||||
ret = evp_aes_decrypt(encrypt_data, encrypt_len, privKey, iVec, data, len);
|
||||
} else {
|
||||
std::cout << "Please use CBC mode or CTR mode, the other modes are not supported!" << std::endl;
|
||||
MS_LOG(ERROR) << "Please use CBC mode or CTR mode, the other modes are not supported!";
|
||||
}
|
||||
if (ret != 1) {
|
||||
return -1;
|
||||
|
@ -83,11 +102,11 @@ int AESEncrypt::evp_aes_encrypt(const unsigned char *data, const int len, const
|
|||
ret = EVP_EncryptInit_ex(ctx, EVP_aes_256_cbc(), NULL, key, ivec);
|
||||
break;
|
||||
default:
|
||||
std::cout << "key length is incorrect!" << std::endl;
|
||||
MS_LOG(ERROR) << "key length is incorrect!";
|
||||
ret = -1;
|
||||
}
|
||||
if (ret != 1) {
|
||||
std::cout << "EVP_EncryptInit_ex CBC fail!" << std::endl;
|
||||
MS_LOG(ERROR) << "EVP_EncryptInit_ex CBC fail!";
|
||||
return -1;
|
||||
}
|
||||
EVP_CIPHER_CTX_set_key_length(ctx, EVP_MAX_KEY_LENGTH);
|
||||
|
@ -101,26 +120,26 @@ int AESEncrypt::evp_aes_encrypt(const unsigned char *data, const int len, const
|
|||
ret = EVP_EncryptInit_ex(ctx, EVP_aes_256_ctr(), NULL, key, ivec);
|
||||
break;
|
||||
default:
|
||||
std::cout << "key length is incorrect!" << std::endl;
|
||||
MS_LOG(ERROR) << "key length is incorrect!";
|
||||
ret = -1;
|
||||
}
|
||||
if (ret != 1) {
|
||||
std::cout << "EVP_EncryptInit_ex CTR fail!" << std::endl;
|
||||
MS_LOG(ERROR) << "EVP_EncryptInit_ex CTR fail!";
|
||||
return -1;
|
||||
}
|
||||
} else {
|
||||
std::cout << "Unsupported AES mode" << std::endl;
|
||||
MS_LOG(ERROR) << "Unsupported AES mode";
|
||||
return -1;
|
||||
}
|
||||
ret = EVP_EncryptUpdate(ctx, encrypt_data, &out_len, data, len);
|
||||
if (ret != 1) {
|
||||
std::cout << "EVP_EncryptUpdate fail!" << std::endl;
|
||||
MS_LOG(ERROR) << "EVP_EncryptUpdate fail!";
|
||||
return -1;
|
||||
}
|
||||
*encrypt_len = out_len;
|
||||
ret = EVP_EncryptFinal_ex(ctx, encrypt_data + *encrypt_len, &out_len);
|
||||
if (ret != 1) {
|
||||
std::cout << "EVP_EncryptFinal_ex fail!" << std::endl;
|
||||
MS_LOG(ERROR) << "EVP_EncryptFinal_ex fail!";
|
||||
return -1;
|
||||
}
|
||||
*encrypt_len += out_len;
|
||||
|
@ -142,7 +161,7 @@ int AESEncrypt::evp_aes_decrypt(const unsigned char *encrypt_data, const int len
|
|||
ret = EVP_DecryptInit_ex(ctx, EVP_aes_256_cbc(), NULL, key, ivec);
|
||||
break;
|
||||
default:
|
||||
std::cout << "key length is incorrect!" << std::endl;
|
||||
MS_LOG(ERROR) << "key length is incorrect!";
|
||||
ret = -1;
|
||||
}
|
||||
if (ret != 1) {
|
||||
|
@ -158,7 +177,7 @@ int AESEncrypt::evp_aes_decrypt(const unsigned char *encrypt_data, const int len
|
|||
ret = EVP_DecryptInit_ex(ctx, EVP_aes_256_ctr(), NULL, key, ivec);
|
||||
break;
|
||||
default:
|
||||
std::cout << "key length is incorrect!" << std::endl;
|
||||
MS_LOG(ERROR) << "key length is incorrect!";
|
||||
ret = -1;
|
||||
}
|
||||
} else {
|
||||
|
@ -182,5 +201,7 @@ int AESEncrypt::evp_aes_decrypt(const unsigned char *encrypt_data, const int len
|
|||
|
||||
return 0;
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace armour
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -17,9 +17,10 @@
|
|||
#ifndef MINDSPORE_ARMOUR_ENCRYPT_H
|
||||
#define MINDSPORE_ARMOUR_ENCRYPT_H
|
||||
|
||||
#ifndef _WIN32
|
||||
#include <openssl/evp.h>
|
||||
#include <assert.h>
|
||||
#include <iostream>
|
||||
#endif
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
#define INIT_VEC_SIZE 16
|
||||
|
||||
|
@ -33,8 +34,6 @@ enum AES_MODE {
|
|||
class SymmetricEncrypt : Encrypt {};
|
||||
|
||||
class AESEncrypt : SymmetricEncrypt {
|
||||
// use openssl EVP_aes_256_cbc/EVP_aes_128_ctr
|
||||
// hash input key to fixed-length (128/256 bits) using md5/SHA-256
|
||||
public:
|
||||
AESEncrypt(const unsigned char *key, int key_len, unsigned char *ivec, int ivec_len, AES_MODE mode);
|
||||
~AESEncrypt();
|
||||
|
@ -47,9 +46,6 @@ class AESEncrypt : SymmetricEncrypt {
|
|||
unsigned char *iVec;
|
||||
int iVecLen;
|
||||
AES_MODE aesMode;
|
||||
// int evp_aes_cbc_encrypt(const unsigned char* data, const int len, const unsigned char* key, unsigned char* ivec,
|
||||
// unsigned char* encrypt_data, int& encrypt_len); int evp_aes_cbc_decrypt(const unsigned char* encrypt_data, const
|
||||
// int len, const unsigned char* key, unsigned char* ivec, unsigned char* decrypt_data, int& decrypt_len);
|
||||
int evp_aes_encrypt(const unsigned char *data, const int len, const unsigned char *key, unsigned char *ivec,
|
||||
unsigned char *encrypt_data, int *encrypt_len);
|
||||
int evp_aes_decrypt(const unsigned char *encrypt_data, const int len, const unsigned char *key, unsigned char *ivec,
|
||||
|
|
|
@ -15,10 +15,37 @@
|
|||
*/
|
||||
|
||||
#include "armour/secure_protocol/key_agreement.h"
|
||||
#include <openssl/evp.h>
|
||||
|
||||
namespace mindspore {
|
||||
namespace armour {
|
||||
#ifdef _WIN32
|
||||
PrivateKey *KeyAgreement::GeneratePrivKey() {
|
||||
MS_LOG(ERROR) << "Unsupported feature in Windows platform.";
|
||||
return NULL;
|
||||
}
|
||||
|
||||
PublicKey *KeyAgreement::GeneratePubKey(PrivateKey *privKey) {
|
||||
MS_LOG(ERROR) << "Unsupported feature in Windows platform.";
|
||||
return NULL;
|
||||
}
|
||||
|
||||
PrivateKey *KeyAgreement::FromPrivateBytes(unsigned char *data, int len) {
|
||||
MS_LOG(ERROR) << "Unsupported feature in Windows platform.";
|
||||
return NULL;
|
||||
}
|
||||
|
||||
PublicKey *KeyAgreement::FromPublicBytes(unsigned char *data, int len) {
|
||||
MS_LOG(ERROR) << "Unsupported feature in Windows platform.";
|
||||
return NULL;
|
||||
}
|
||||
|
||||
int KeyAgreement::ComputeSharedKey(PrivateKey *privKey, PublicKey *peerPublicKey, int key_len,
|
||||
const unsigned char *salt, int salt_len, unsigned char *exchangeKey) {
|
||||
MS_LOG(ERROR) << "Unsupported feature in Windows platform.";
|
||||
return -1;
|
||||
}
|
||||
|
||||
#else
|
||||
PublicKey::PublicKey(EVP_PKEY *evpKey) { evpPubKey = evpKey; }
|
||||
|
||||
PublicKey::~PublicKey() { EVP_PKEY_free(evpPubKey); }
|
||||
|
@ -47,31 +74,31 @@ int PrivateKey::Exchange(PublicKey *peerPublicKey, int key_len, const unsigned c
|
|||
size_t len = 0;
|
||||
ctx = EVP_PKEY_CTX_new(evpPrivKey, NULL);
|
||||
if (!ctx) {
|
||||
std::cout << "EVP_PKEY_CTX_new failed!" << std::endl;
|
||||
MS_LOG(ERROR) << "EVP_PKEY_CTX_new failed!";
|
||||
return -1;
|
||||
}
|
||||
if (EVP_PKEY_derive_init(ctx) <= 0) {
|
||||
std::cout << "EVP_PKEY_derive_init failed!" << std::endl;
|
||||
MS_LOG(ERROR) << "EVP_PKEY_derive_init failed!";
|
||||
return -1;
|
||||
}
|
||||
if (EVP_PKEY_derive_set_peer(ctx, peerPublicKey->evpPubKey) <= 0) {
|
||||
std::cout << "EVP_PKEY_derive_set_peer failed!" << std::endl;
|
||||
MS_LOG(ERROR) << "EVP_PKEY_derive_set_peer failed!";
|
||||
return -1;
|
||||
}
|
||||
unsigned char *secret;
|
||||
if (EVP_PKEY_derive(ctx, NULL, &len) <= 0) {
|
||||
std::cout << "get derive key size failed!" << std::endl;
|
||||
MS_LOG(ERROR) << "get derive key size failed!";
|
||||
return -1;
|
||||
}
|
||||
|
||||
secret = (unsigned char *)OPENSSL_malloc(len);
|
||||
if (!secret) {
|
||||
std::cout << "malloc secret memory failed!" << std::endl;
|
||||
MS_LOG(ERROR) << "malloc secret memory failed!";
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (EVP_PKEY_derive(ctx, secret, &len) <= 0) {
|
||||
std::cout << "derive key failed!" << std::endl;
|
||||
MS_LOG(ERROR) << "derive key failed!";
|
||||
return -1;
|
||||
}
|
||||
|
||||
|
@ -97,7 +124,6 @@ PrivateKey *KeyAgreement::GeneratePrivKey() {
|
|||
return NULL;
|
||||
}
|
||||
EVP_PKEY_CTX_free(pctx);
|
||||
// PEM_write_PrivateKey(stdout, evpKey, NULL, NULL, 0, NULL, NULL);
|
||||
PrivateKey *privKey = new PrivateKey(evpKey);
|
||||
return privKey;
|
||||
}
|
||||
|
@ -130,7 +156,7 @@ PrivateKey *KeyAgreement::FromPrivateBytes(unsigned char *data, int len) {
|
|||
PublicKey *KeyAgreement::FromPublicBytes(unsigned char *data, int len) {
|
||||
EVP_PKEY *evp_pubKey = EVP_PKEY_new_raw_public_key(EVP_PKEY_X25519, NULL, data, len);
|
||||
if (evp_pubKey == NULL) {
|
||||
std::cout << "create evp_pubKey from raw bytes fail" << std::endl;
|
||||
MS_LOG(ERROR) << "create evp_pubKey from raw bytes fail";
|
||||
return NULL;
|
||||
}
|
||||
PublicKey *pubKey = new PublicKey(evp_pubKey);
|
||||
|
@ -141,6 +167,7 @@ int KeyAgreement::ComputeSharedKey(PrivateKey *privKey, PublicKey *peerPublicKey
|
|||
const unsigned char *salt, int salt_len, unsigned char *exchangeKey) {
|
||||
return privKey->Exchange(peerPublicKey, key_len, salt, salt_len, exchangeKey);
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace armour
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -17,16 +17,23 @@
|
|||
#ifndef MINDSPORE_KEY_AGREEMENT_H
|
||||
#define MINDSPORE_KEY_AGREEMENT_H
|
||||
|
||||
#ifndef _WIN32
|
||||
#include <openssl/dh.h>
|
||||
#include <openssl/pem.h>
|
||||
#include <openssl/evp.h>
|
||||
#include <iostream>
|
||||
#endif
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
#define KEK_KEY_LEN 32
|
||||
#define ITERATION 10000
|
||||
|
||||
namespace mindspore {
|
||||
namespace armour {
|
||||
|
||||
#ifdef _WIN32
|
||||
class PublicKey {};
|
||||
class PrivateKey {};
|
||||
#else
|
||||
class PublicKey {
|
||||
public:
|
||||
explicit PublicKey(EVP_PKEY *evpKey);
|
||||
|
@ -44,6 +51,7 @@ class PrivateKey {
|
|||
int GetPublicBytes(size_t *len, unsigned char *pubKeyBytes);
|
||||
EVP_PKEY *evpPrivKey;
|
||||
};
|
||||
#endif
|
||||
|
||||
class KeyAgreement {
|
||||
public:
|
||||
|
|
|
@ -15,18 +15,13 @@
|
|||
*/
|
||||
|
||||
#include "armour/secure_protocol/random.h"
|
||||
#include <vector>
|
||||
|
||||
namespace mindspore {
|
||||
namespace armour {
|
||||
Random::Random(size_t init_seed) { generator.seed(init_seed); }
|
||||
|
||||
Random::~Random() {}
|
||||
|
||||
int Random::GetRandomBytes(unsigned char *secret, int num_bytes) {
|
||||
int retval = RAND_priv_bytes(secret, RANDOM_LEN);
|
||||
return retval;
|
||||
}
|
||||
|
||||
void Random::RandUniform(float *array, int size) {
|
||||
std::uniform_real_distribution<double> rand(0, 1);
|
||||
for (int i = 0; i < size; i++) {
|
||||
|
@ -41,6 +36,23 @@ void Random::RandNorminal(float *array, int size) {
|
|||
}
|
||||
}
|
||||
|
||||
#ifdef _WIN32
|
||||
int Random::GetRandomBytes(unsigned char *secret, int num_bytes) {
|
||||
MS_LOG(ERROR) << "Unsupported feature in Windows platform.";
|
||||
return -1;
|
||||
}
|
||||
|
||||
int Random::RandomAESCTR(std::vector<float> *noise, int noise_len, const unsigned char *seed, int seed_len) {
|
||||
MS_LOG(ERROR) << "Unsupported feature in Windows platform.";
|
||||
return -1;
|
||||
}
|
||||
|
||||
#else
|
||||
int Random::GetRandomBytes(unsigned char *secret, int num_bytes) {
|
||||
int retval = RAND_priv_bytes(secret, RANDOM_LEN);
|
||||
return retval;
|
||||
}
|
||||
|
||||
int Random::RandomAESCTR(std::vector<float> *noise, int noise_len, const unsigned char *seed, int seed_len) {
|
||||
if (seed_len != 16 && seed_len != 32) {
|
||||
std::cout << "seed length must be 16 or 32!" << std::endl;
|
||||
|
@ -70,5 +82,7 @@ int Random::RandomAESCTR(std::vector<float> *noise, int noise_len, const unsigne
|
|||
}
|
||||
return 0;
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace armour
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -16,10 +16,12 @@
|
|||
|
||||
#ifndef MINDSPORE_ARMOUR_RANDOM_H
|
||||
#define MINDSPORE_ARMOUR_RANDOM_H
|
||||
#include <openssl/rand.h>
|
||||
|
||||
#include <random>
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
#ifndef _WIN32
|
||||
#include <openssl/rand.h>
|
||||
#endif
|
||||
#include "armour/secure_protocol/encrypt.h"
|
||||
namespace mindspore {
|
||||
namespace armour {
|
||||
|
|
|
@ -13,8 +13,9 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "armour/secure_protocol/secret_sharing.h"
|
||||
#include <string>
|
||||
|
||||
namespace mindspore {
|
||||
namespace armour {
|
||||
void secure_zero(unsigned char *s, size_t n) {
|
||||
|
@ -23,12 +24,13 @@ void secure_zero(unsigned char *s, size_t n) {
|
|||
while (n--) *p++ = '\0';
|
||||
}
|
||||
|
||||
#ifndef _WIN32
|
||||
int GetRandInteger(mpz_t x, mpz_t prim) {
|
||||
size_t bytes_len = (mpz_sizeinbase(prim, 2) + 8 - 1) / 8;
|
||||
unsigned char buf[bytes_len];
|
||||
while (true) {
|
||||
if (!RAND_bytes(buf, bytes_len)) {
|
||||
std::cout << "Get Rand Integer failed!" << std::endl;
|
||||
MS_LOG(WARNING) << "Get Rand Integer failed!";
|
||||
continue;
|
||||
}
|
||||
mpz_import(x, bytes_len, 1, 1, 0, 0, buf);
|
||||
|
@ -45,7 +47,7 @@ int GetRandomPrime(mpz_t prim) {
|
|||
const int max_prime_len = SECRET_MAX_LEN + 1;
|
||||
unsigned char buf[max_prime_len];
|
||||
if (!RAND_bytes(buf, max_prime_len)) {
|
||||
std::cout << "Get Rand Integer failed!" << std::endl;
|
||||
MS_LOG(ERROR) << "Get Rand Integer failed!";
|
||||
return -1;
|
||||
}
|
||||
mpz_import(rand, max_prime_len, 1, 1, 0, 0, buf);
|
||||
|
@ -58,7 +60,7 @@ int GetRandomPrime(mpz_t prim) {
|
|||
void PrintBigInteger(mpz_t x) {
|
||||
char *tmp = mpz_get_str(NULL, 16, x);
|
||||
std::string Str = tmp;
|
||||
std::cout << "*************************" << Str << std::endl;
|
||||
MS_LOG(INFO) << Str;
|
||||
void (*freefunc)(void *, size_t);
|
||||
mp_get_memory_functions(NULL, NULL, &freefunc);
|
||||
freefunc(tmp, strlen(tmp) + 1);
|
||||
|
@ -68,7 +70,7 @@ void PrintBigInteger(mpz_t x, int hex) {
|
|||
char *tmp = mpz_get_str(NULL, hex, x);
|
||||
std::string Str = tmp;
|
||||
|
||||
std::cout << Str << std::endl;
|
||||
MS_LOG(INFO) << Str;
|
||||
void (*freefunc)(void *, size_t);
|
||||
mp_get_memory_functions(NULL, NULL, &freefunc);
|
||||
freefunc(tmp, strlen(tmp) + 1);
|
||||
|
@ -118,10 +120,10 @@ int SecretSharing::CalculateShares(const mpz_t coeff[], int k, int n, const std:
|
|||
shares[i]->data = (unsigned char *)malloc(share_len + 1);
|
||||
mpz_export(shares[i]->data, &(shares[i]->len), 1, 1, 0, 0, y);
|
||||
if (shares[i]->len != share_len) {
|
||||
std::cout << "share_len is not equal" << std::endl;
|
||||
MS_LOG(ERROR) << "share_len is not equal";
|
||||
return -1;
|
||||
}
|
||||
std::cout << "share_" << i + 1 << ": ";
|
||||
MS_LOG(INFO) << "share_" << i + 1 << ": ";
|
||||
PrintBigInteger(y);
|
||||
}
|
||||
mpz_clear(x);
|
||||
|
@ -132,11 +134,11 @@ int SecretSharing::CalculateShares(const mpz_t coeff[], int k, int n, const std:
|
|||
int SecretSharing::Split(int n, const int k, const char *secret, const size_t length,
|
||||
const std::vector<Share *> &shares) {
|
||||
if (k <= 1 || k > n) {
|
||||
std::cout << "invalid parameters" << std::endl;
|
||||
MS_LOG(ERROR) << "invalid parameters";
|
||||
return -1;
|
||||
}
|
||||
if (static_cast<int>(shares.size()) != n) {
|
||||
std::cout << "the size of shares must be equal to nq" << std::endl;
|
||||
MS_LOG(ERROR) << "the size of shares must be equal to n";
|
||||
return -1;
|
||||
}
|
||||
this->degree_ = length * 8;
|
||||
|
@ -150,10 +152,15 @@ int SecretSharing::Split(int n, const int k, const char *secret, const size_t le
|
|||
for (; i < k && ret == 0; i++) {
|
||||
mpz_init(coeff[i]);
|
||||
ret = GetRandInteger(coeff[i], this->prim_);
|
||||
std::cout << "coeff_" << i << ":";
|
||||
if (ret != 0) {
|
||||
break;
|
||||
}
|
||||
MS_LOG(INFO) << "coeff_" << i << ":";
|
||||
PrintBigInteger(coeff[i]);
|
||||
}
|
||||
if (ret == 0) ret = CalculateShares(coeff, k, n, shares);
|
||||
if (ret == 0) {
|
||||
ret = CalculateShares(coeff, k, n, shares);
|
||||
}
|
||||
for (i = 0; i < k; i++) mpz_clear(coeff[i]);
|
||||
return ret;
|
||||
}
|
||||
|
@ -174,10 +181,10 @@ int SecretSharing::Combine(int k, const std::vector<Share *> &shares, char *secr
|
|||
mpz_init(denses[i]);
|
||||
mpz_init(nums[i]);
|
||||
GetShare(x[i], y[i], shares[i]);
|
||||
std::cout << "combine -- share_" << mpz_get_str(NULL, 10, x[i]) << ": ";
|
||||
MS_LOG(INFO) << "combine -- share_" << mpz_get_str(NULL, 10, x[i]) << ": ";
|
||||
PrintBigInteger(y[i]);
|
||||
printf("index is : %u\n", shares[i]->index);
|
||||
printf("len is %zu.\n", shares[i]->len);
|
||||
MS_LOG(INFO) << "index is : " << shares[i]->index;
|
||||
MS_LOG(INFO) << "len is %zu " << shares[i]->len;
|
||||
}
|
||||
|
||||
mpz_t sum;
|
||||
|
@ -215,6 +222,6 @@ int SecretSharing::Combine(int k, const std::vector<Share *> &shares, char *secr
|
|||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
#endif
|
||||
} // namespace armour
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -13,16 +13,16 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_SECRET_SHARING_H
|
||||
#define MINDSPORE_SECRET_SHARING_H
|
||||
#include <string.h>
|
||||
#include <unistd.h>
|
||||
#include <fcntl.h>
|
||||
#include <assert.h>
|
||||
#ifndef _WIN32
|
||||
#include <gmp.h>
|
||||
#include <vector>
|
||||
#include <iostream>
|
||||
#include "openssl/rand.h"
|
||||
#endif
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace armour {
|
||||
|
@ -36,6 +36,7 @@ struct Share {
|
|||
~Share();
|
||||
};
|
||||
|
||||
#ifndef _WIN32
|
||||
void secure_zero(void *s, size_t);
|
||||
int GetRandInteger(mpz_t x, mpz_t prim);
|
||||
int GetRandomPrime(mpz_t prim);
|
||||
|
@ -67,6 +68,8 @@ class SecretSharing {
|
|||
// convert secret sharing from Share type to mpz_t type
|
||||
void GetShare(mpz_t x, mpz_t share, Share *s_share);
|
||||
};
|
||||
#endif
|
||||
|
||||
} // namespace armour
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_SECRET_SHARING_H
|
||||
|
|
|
@ -58,6 +58,12 @@ if(NOT ENABLE_CPU OR WIN32)
|
|||
list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/round/get_model_kernel.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/round/pull_weight_kernel.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/round/push_weight_kernel.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/round/client_list_kernel.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/round/exchange_keys_kernel.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/round/get_keys_kernel.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/round/get_secrets_kernel.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/round/reconstruct_secrets_kernel.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/round/share_secrets_kernel.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/params_info.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "server/consistent_hash_ring.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "server/iteration_timer.cc")
|
||||
|
|
Loading…
Reference in New Issue