fl compression

This commit is contained in:
wtcheng 2022-03-16 14:32:22 +08:00
parent 0373d2f915
commit ce17db99a6
34 changed files with 1715 additions and 123 deletions

View File

@ -51,6 +51,8 @@ if(NOT ENABLE_CPU OR WIN32)
list(REMOVE_ITEM _FL_SRC_FILES "armour/cipher/cipher_reconstruct.cc")
list(REMOVE_ITEM _FL_SRC_FILES "armour/cipher/cipher_shares.cc")
list(REMOVE_ITEM _FL_SRC_FILES "armour/cipher/cipher_unmask.cc")
list(REMOVE_ITEM _FL_SRC_FILES "compression/decode_executor.cc")
list(REMOVE_ITEM _FL_SRC_FILES "compression/encode_executor.cc")
endif()
if(CMAKE_SYSTEM_NAME MATCHES "Darwin")

View File

@ -0,0 +1,150 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "fl/compression/decode_executor.h"
namespace mindspore {
namespace fl {
namespace compression {
std::vector<int> DecodeExecutor::ConstructMaskArray(int seed, float upload_sparse_rate, size_t param_num) {
  // Rebuilds, on the server side, the pseudo-random 0/1 sparse mask the client
  // used when uploading: the first `retain_num` slots are set to 1 and the
  // array is then shuffled by a Lehmer-style linear congruential generator
  // seeded identically on both ends, so the retained-parameter positions match.
  //
  // NOTE(review): the constant names appear swapped -- 2147483647 (2^31 - 1) is
  // the Lehmer MODULUS and 48271 the MULTIPLIER.  Renaming would not change
  // behavior; left as-is to stay bit-compatible with the client implementation.
  static int multiplier = 2147483647;
  // 2 * (2^31 - 1): maps the full signed seed range into (0, 1) below.
  static double increment = 4294967294.0;
  static int modulo = 48271;
  // Number of parameters the client actually uploaded under this sparse rate.
  size_t retain_num = size_t(static_cast<float>(param_num) * upload_sparse_rate);
  if (retain_num == 0) {
    MS_LOG(WARNING) << "The retain_num is 0, and upload_sparse_rate is too small.";
  }
  std::vector<int> mask_array(param_num, 0);
  for (size_t i = 0; i < retain_num; ++i) {
    mask_array[i] = 1;
  }
  // NOTE(review): these signed multiplications can overflow `int`, which is UB
  // in C++; in practice they wrap.  Any fix (e.g. 64-bit intermediates) must be
  // mirrored on the client or the reconstructed masks will diverge.
  seed = ((seed + multiplier) * modulo) % multiplier;
  // Fisher-Yates-style shuffle driven by the shared seed.
  for (size_t i = 0; i < param_num; ++i) {
    // generate random number in (0, 1)
    double rand = static_cast<double>(seed) / increment + 0.5;
    // update seed
    seed = (seed * modulo) % multiplier;
    size_t j = size_t(rand * static_cast<double>(param_num - i)) + i;
    int temp = mask_array[i];
    mask_array[i] = mask_array[j];
    mask_array[j] = temp;
  }
  return mask_array;
}
bool DecodeExecutor::DeQuantSparseDiff(std::map<std::string, std::vector<float>> *weight_map,
const std::vector<CompressFeatureMap> &compress_feature_maps, size_t num_bits,
float upload_sparse_rate, int seed, const std::vector<std::string> &name_vec,
size_t data_size) {
std::vector<std::vector<float>> decompress_feature_maps;
// origin parameters
std::vector<size_t> shape_vec;
size_t param_num = 0;
const auto &iter_to_model = mindspore::fl::server::ModelStore::GetInstance().iteration_to_model();
size_t latest_iter_num = iter_to_model.rbegin()->first;
std::map<std::string, AddressPtr> feature_maps =
mindspore::fl::server::ModelStore::GetInstance().GetModelByIterNum(latest_iter_num);
// get shape vector and number of upload parameters
for (const auto &name : name_vec) {
size_t shape = feature_maps[name]->size / sizeof(float);
shape_vec.emplace_back(shape);
param_num += shape;
}
MS_LOG(DEBUG) << "Compression get last weights success!";
// quant decode
auto temp1 = static_cast<float>(1 << num_bits) - 1.0f;
auto temp2 = static_cast<float>(1 << (num_bits - 1));
std::vector<float> de_min_max_feature_map;
for (auto compress_feature_map : compress_feature_maps) {
float min_val = compress_feature_map.min_val;
float max_val = compress_feature_map.max_val;
float scale_val = static_cast<float>(max_val - min_val) / temp1 + 1e-10f;
size_t size = compress_feature_map.compress_data.size();
for (size_t i = 0; i < size; ++i) {
de_min_max_feature_map.emplace_back(
(static_cast<float>(compress_feature_map.compress_data[i]) + temp2) * scale_val + min_val);
}
}
MS_LOG(DEBUG) << "Compression quant decode success!";
// sparse decode
std::vector<int> mask_array = ConstructMaskArray(seed, upload_sparse_rate, param_num);
size_t index = 0;
size_t de_min_max_feature_map_index = 0;
for (const auto &shape : shape_vec) {
std::vector<float> feature_map(shape);
for (size_t i = 0; i < shape; ++i) {
if (index >= mask_array.size()) {
MS_LOG(WARNING) << "The mask_array and parameter shape is not matched.";
return false;
}
if (mask_array[index] == 1) {
if (de_min_max_feature_map_index >= de_min_max_feature_map.size()) {
MS_LOG(WARNING) << "The number of upload parameters is too small.";
return false;
}
feature_map[i] = de_min_max_feature_map[de_min_max_feature_map_index];
de_min_max_feature_map_index += 1;
} else {
feature_map[i] = 0.0f;
}
index += 1;
}
decompress_feature_maps.emplace_back(feature_map);
}
MS_LOG(DEBUG) << "Compression sparse decode success!";
// difference decode
for (size_t i = 0; i < decompress_feature_maps.size(); ++i) {
size_t feature_size = decompress_feature_maps[i].size();
std::string name = name_vec[i];
float *weight_data = reinterpret_cast<float *>(feature_maps[name]->addr);
auto &weight_item = (*weight_map)[name];
weight_item.resize(feature_size);
for (size_t j = 0; j < feature_size; ++j) {
weight_item[j] = decompress_feature_maps[i][j] + data_size * weight_data[j];
}
}
MS_LOG(DEBUG) << "Compression difference decode success!";
return true;
}
bool DecodeExecutor::Decode(std::map<std::string, std::vector<float>> *weight_map,
                            const std::vector<CompressFeatureMap> &compress_feature_maps,
                            schema::CompressType upload_compress_type, float upload_sparse_rate, int seed,
                            const std::vector<std::string> &name_vec, size_t data_size) {
  // Dispatch on the client's upload compress type.  Only DIFF_SPARSE_QUANT is
  // currently supported; every other type is rejected.
  if (upload_compress_type != schema::CompressType_DIFF_SPARSE_QUANT) {
    return false;
  }
  const size_t quant_bits = 8;  // bit width used by the quantization stage
  return DeQuantSparseDiff(weight_map, compress_feature_maps, quant_bits, upload_sparse_rate, seed, name_vec,
                           data_size);
}
schema::CompressType DecodeExecutor::GetCompressType(schema::CompressType upload_compress_type) {
  // Normalize the requested upload compress type: anything other than
  // DIFF_SPARSE_QUANT degrades to NO_COMPRESS.
  if (upload_compress_type != schema::CompressType_DIFF_SPARSE_QUANT) {
    MS_LOG(DEBUG) << "This upload compress type is NO_COMPRESS.";
    return schema::CompressType_NO_COMPRESS;
  }
  MS_LOG(DEBUG) << "This upload compress type is DIFF_SPARSE_QUANT.";
  return schema::CompressType_DIFF_SPARSE_QUANT;
}
} // namespace compression
} // namespace fl
} // namespace mindspore

View File

@ -0,0 +1,74 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FL_COMPRESSION_DECODE_EXECUTOR_H_
#define MINDSPORE_CCSRC_FL_COMPRESSION_DECODE_EXECUTOR_H_
#include <memory>
#include <string>
#include <vector>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <functional>
#include <algorithm>
#include <regex>
#include <map>
#include <utility>
#include "proto/comm.pb.h"
#include "schema/fl_job_generated.h"
#include "schema/cipher_generated.h"
#include "fl/server/model_store.h"
#include "fl/server/common.h"
#include "ps/ps_context.h"
namespace mindspore {
namespace fl {
namespace compression {
// One compressed weight tensor as uploaded by a client: the quantized int8
// payload plus the (min, max) range needed to dequantize it.
struct CompressFeatureMap {
  std::string weight_fullname;
  std::vector<int8_t> compress_data;
  // Dequantization range.  Zero-initialized so a default-constructed entry
  // never carries indeterminate float values.
  float min_val{0.0f};
  float max_val{0.0f};
};
// Singleton that decodes compressed weight uploads from clients
// (min-max quantization + random sparse mask + parameter difference).
class DecodeExecutor {
 public:
  // Meyers-singleton accessor; thread-safe initialization since C++11.
  static DecodeExecutor &GetInstance() {
    static DecodeExecutor instance;
    return instance;
  }
  // construct mask array for random sparse
  // Rebuilds the client's pseudo-random 0/1 mask of length param_num with
  // roughly upload_sparse_rate * param_num ones, from the shared seed.
  std::vector<int> ConstructMaskArray(int seed, float upload_sparse_rate, size_t param_num);
  // decode min_max quantization and random sparse and parameter difference
  // Writes the reconstructed full-size tensors into *weight_map; returns false
  // on any name/shape/size mismatch.
  bool DeQuantSparseDiff(std::map<std::string, std::vector<float>> *weight_map,
                         const std::vector<CompressFeatureMap> &compress_feature_maps, size_t num_bits,
                         float upload_sparse_rate, int seed, const std::vector<std::string> &name_vec,
                         size_t data_size);
  // decode
  // Entry point: dispatches on upload_compress_type (only DIFF_SPARSE_QUANT is
  // currently supported) and returns false for unsupported types.
  bool Decode(std::map<std::string, std::vector<float>> *weight_map,
              const std::vector<CompressFeatureMap> &compress_feature_maps, schema::CompressType upload_compress_type,
              float upload_sparse_rate, int seed, const std::vector<std::string> &name_vec, size_t data_size);
  // Maps any unsupported type to CompressType_NO_COMPRESS.
  schema::CompressType GetCompressType(schema::CompressType upload_compress_type);
};
} // namespace compression
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FL_COMPRESSION_DECODE_EXECUTOR_H_

View File

@ -0,0 +1,102 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "fl/compression/encode_executor.h"
#include <arpa/inet.h>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <functional>
#include <algorithm>
#include <regex>
#include <map>
#include <utility>
#include <vector>
#include "fl/server/common.h"
namespace mindspore {
namespace fl {
namespace compression {
bool CompressExecutor::EnableCompressWeight(const schema::CompressType compressType) {
  // Weight compression is available only for types registered in kCompressTypeMap.
  return kCompressTypeMap.find(compressType) != kCompressTypeMap.end();
}
bool CompressExecutor::construct_compress_weight(std::map<std::string, CompressWeight> *compressWeights,
                                                 std::map<std::string, std::vector<float>> feature_maps,
                                                 const schema::CompressType compressType) {
  // Only min-max quantization is implemented; any other compress type is
  // rejected without touching *compressWeights.
  if (compressType != schema::CompressType_QUANT) {
    return false;
  }
  // kCompressTypeMap maps the type to its bit width (QUANT -> 8).
  return quant_min_max(compressWeights, feature_maps, kCompressTypeMap.at(compressType));
}
bool CompressExecutor::quant_min_max(std::map<std::string, CompressWeight> *compressWeights,
std::map<std::string, std::vector<float>> feature_maps, size_t num_bits) {
auto temp1 = static_cast<float>(1 << num_bits) - 1.0f;
auto temp2 = static_cast<float>(1 << (num_bits - 1));
for (const auto &feature_map : feature_maps) {
std::string weight_name = feature_map.first;
float min_value = 1e10f;
float max_value = -min_value;
for (const auto &feature : feature_map.second) {
if (feature > max_value) {
max_value = feature;
}
if (feature < min_value) {
min_value = feature;
}
}
float scale_value = (max_value - min_value) / temp1 + 1e-10f;
size_t size = feature_map.second.size();
if (size == 0) {
MS_LOG(WARNING) << "The size of parameters is zero.";
return false;
}
CompressWeight compressWeight;
for (size_t i = 0; i < size; ++i) {
auto round_data = round((feature_map.second[i] - min_value) / scale_value - temp2);
// bit pack can be implemented here in the future
auto int8_data = int8_t(round_data);
compressWeight.compress_data.emplace_back(int8_data);
}
compressWeight.min_val = min_value;
compressWeight.max_val = max_value;
compressWeight.compress_data_len = size;
(*compressWeights)[weight_name] = compressWeight;
}
return true;
}
schema::CompressType CompressExecutor::GetCompressType(const flatbuffers::Vector<int8_t> *download_compress_types) {
  // Scan the compress types the client advertises for download; pick QUANT when
  // present, otherwise fall back to NO_COMPRESS (also when the field is absent).
  if (download_compress_types == nullptr) {
    MS_LOG(DEBUG) << "The client does not support current download compress type.";
    return schema::CompressType_NO_COMPRESS;
  }
  for (size_t i = 0; i < download_compress_types->size(); ++i) {
    if (download_compress_types->Get(i) == schema::CompressType_QUANT) {
      return schema::CompressType_QUANT;
    }
  }
  return schema::CompressType_NO_COMPRESS;
}
} // namespace compression
} // namespace fl
} // namespace mindspore

View File

@ -0,0 +1,68 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FL_COMPRESSION_ENCODE_EXECUTOR_H_
#define MINDSPORE_CCSRC_FL_COMPRESSION_ENCODE_EXECUTOR_H_
#include <memory>
#include <string>
#include <vector>
#include <map>
#include "proto/comm.pb.h"
#include "schema/fl_job_generated.h"
#include "schema/cipher_generated.h"
#include "fl/armour/secure_protocol/key_agreement.h"
#include "ps/ps_context.h"
#include "ps/core/worker_node.h"
#include "ps/core/cluster_metadata.h"
#include "ps/core/communicator/tcp_communicator.h"
#include "fl/server/common.h"
namespace mindspore {
namespace fl {
namespace compression {
// compress type map: schema::CompressType -> num bits
const std::map<schema::CompressType, size_t> kCompressTypeMap = {{schema::CompressType_QUANT, 8}};
// Quantized representation of one weight tensor produced by quant_min_max:
// the int8 payload plus the (min, max) range clients need to dequantize.
struct CompressWeight {
  std::vector<int8_t> compress_data;
  // Number of valid entries in compress_data; zero-initialized so a
  // default-constructed value never carries an indeterminate length.
  size_t compress_data_len{0};
  // Dequantization range recorded at encode time.
  float min_val{0.0f};
  float max_val{0.0f};
};
// Singleton that builds compressed (min-max quantized) copies of server-side
// weights for download by clients that advertise compression support.
class CompressExecutor {
 public:
  // Meyers-singleton accessor; thread-safe initialization since C++11.
  static CompressExecutor &GetInstance() {
    static CompressExecutor instance;
    return instance;
  }
  // True when compressType is registered in kCompressTypeMap (i.e. supported).
  bool EnableCompressWeight(const schema::CompressType compressType);
  // Fills *compressWeights from feature_maps; only CompressType_QUANT is
  // implemented, other types return false.
  bool construct_compress_weight(std::map<std::string, CompressWeight> *compressWeights,
                                 std::map<std::string, std::vector<float>> feature_maps,
                                 const schema::CompressType compressType);
  // Min-max quantization of each tensor onto the signed num_bits integer range;
  // returns false if any tensor is empty.
  bool quant_min_max(std::map<std::string, CompressWeight> *compressWeights,
                     std::map<std::string, std::vector<float>> feature_maps, size_t num_bits);
  // Picks QUANT if the client's advertised type list contains it, otherwise
  // NO_COMPRESS (also when the list pointer is null).
  schema::CompressType GetCompressType(const flatbuffers::Vector<int8_t> *download_compress_types);
};
} // namespace compression
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FL_COMPRESSION_ENCODE_EXECUTOR_H_

View File

@ -149,6 +149,11 @@ constexpr auto kUpdateModelRejectClientNum = "updateModelRejectClientNum";
constexpr auto kGetModelTotalClientNum = "getModelTotalClientNum";
constexpr auto kGetModelAcceptClientNum = "getModelAcceptClientNum";
constexpr auto kGetModelRejectClientNum = "getModelRejectClientNum";
constexpr auto kMinVal = "min_val";
constexpr auto kMaxVal = "max_val";
constexpr auto kQuant = "QUANT";
constexpr auto kDiffSparseQuant = "DIFF_SPARSE_QUANT";
constexpr auto kNoCompress = "NO_COMPRESS";
// OptimParamNameToIndex represents every inputs/workspace/outputs parameter's offset when an optimizer kernel is
// launched.

View File

@ -588,6 +588,7 @@ void Iteration::Next(bool is_iteration_valid, const std::string &reason) {
if (LocalMetaStore::GetInstance().verifyAggregationFeatureMap(feature_map)) {
ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model);
ModelStore::GetInstance().StoreCompressModelByIterNum(iteration_num_, model);
iteration_result_ = IterationResult::kSuccess;
MS_LOG(INFO) << "Iteration " << iteration_num_ << " is successfully finished.";
} else {
@ -599,6 +600,7 @@ void Iteration::Next(bool is_iteration_valid, const std::string &reason) {
size_t latest_iter_num = iter_to_model.rbegin()->first;
const auto &model = ModelStore::GetInstance().GetModelByIterNum(latest_iter_num);
ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model);
ModelStore::GetInstance().StoreCompressModelByIterNum(iteration_num_, model);
iteration_result_ = IterationResult::kFail;
MS_LOG(WARNING) << "Iteration " << iteration_num_ << " is invalid. Reason: " << reason;
}

View File

@ -92,7 +92,7 @@ void GetModelKernel::GetModel(const schema::RequestGetModel *get_model_req,
return;
}
auto next_req_time = LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp);
std::map<std::string, AddressPtr> feature_maps;
std::map<std::string, AddressPtr> feature_maps = {};
size_t current_iter = LocalMetaStore::GetInstance().curr_iter_num();
size_t get_model_iter = IntToSize(get_model_req->iteration());
const auto &iter_to_model = ModelStore::GetInstance().iteration_to_model();
@ -110,6 +110,7 @@ void GetModelKernel::GetModel(const schema::RequestGetModel *get_model_req,
SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize());
return;
}
IncreaseAcceptClientNum();
auto real_get_model_iter = get_model_iter;
if (iter_to_model.count(get_model_iter) == 0) {
@ -118,12 +119,37 @@ void GetModelKernel::GetModel(const schema::RequestGetModel *get_model_req,
<< " is invalid. Current iteration is " << std::to_string(current_iter);
real_get_model_iter = latest_iter_num;
}
auto cache = ModelStore::GetInstance().GetModelResponseCache(name_, current_iter, real_get_model_iter);
auto download_compress_types = get_model_req->download_compress_types();
schema::CompressType compressType =
mindspore::fl::compression::CompressExecutor::GetInstance().GetCompressType(download_compress_types);
std::string compress_type;
if (compressType == schema::CompressType_QUANT) {
compress_type = kQuant;
} else {
compress_type = kNoCompress;
}
auto cache = ModelStore::GetInstance().GetModelResponseCache(name_, current_iter, real_get_model_iter, compress_type);
if (cache == nullptr) {
feature_maps = ModelStore::GetInstance().GetModelByIterNum(real_get_model_iter);
// Only download compress weights if client support.
std::map<std::string, AddressPtr> compress_feature_maps = {};
if (compressType == schema::CompressType_NO_COMPRESS) {
feature_maps = ModelStore::GetInstance().GetModelByIterNum(real_get_model_iter);
} else {
auto compressExecutor = mindspore::fl::compression::CompressExecutor::GetInstance();
if (compressExecutor.EnableCompressWeight(compressType)) {
const auto &iter_to_compress_model = ModelStore::GetInstance().iteration_to_compress_model();
if (iter_to_compress_model.count(get_model_iter) == 0) {
MS_LOG(DEBUG) << "The iteration of GetCompressModel request " << std::to_string(get_model_iter)
<< " is invalid. Current iteration is " << std::to_string(current_iter);
compress_feature_maps = ModelStore::GetInstance().GetCompressModelByIterNum(latest_iter_num, compressType);
} else {
compress_feature_maps = ModelStore::GetInstance().GetCompressModelByIterNum(get_model_iter, compressType);
}
}
}
BuildGetModelRsp(fbb, schema::ResponseCode_SUCCEED, "Get model for iteration " + std::to_string(get_model_iter),
current_iter, feature_maps, std::to_string(next_req_time));
cache = ModelStore::GetInstance().StoreModelResponseCache(name_, current_iter, real_get_model_iter,
current_iter, feature_maps, std::to_string(next_req_time), compressType, compress_feature_maps);
cache = ModelStore::GetInstance().StoreModelResponseCache(name_, current_iter, real_get_model_iter, compress_type,
fbb->GetBufferPointer(), fbb->GetSize());
if (cache == nullptr) {
SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize());
@ -131,7 +157,7 @@ void GetModelKernel::GetModel(const schema::RequestGetModel *get_model_req,
}
}
SendResponseMsgInference(message, cache->data(), cache->size(), ModelStore::GetInstance().RelModelResponseCache);
MS_LOG(DEBUG) << "GetModel last iteratin is valid or not: " << Iteration::GetInstance().is_last_iteration_valid()
MS_LOG(DEBUG) << "GetModel last iteration is valid or not: " << Iteration::GetInstance().is_last_iteration_valid()
<< ", next request time is " << next_req_time << ", current iteration is " << current_iter;
return;
}
@ -139,7 +165,8 @@ void GetModelKernel::GetModel(const schema::RequestGetModel *get_model_req,
void GetModelKernel::BuildGetModelRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode,
const std::string &reason, const size_t iter,
const std::map<std::string, AddressPtr> &feature_maps,
const std::string &timestamp) {
const std::string &timestamp, const schema::CompressType &compressType,
const std::map<std::string, AddressPtr> &compress_feature_maps) {
if (fbb == nullptr) {
MS_LOG(ERROR) << "Input fbb is nullptr.";
return;
@ -156,12 +183,40 @@ void GetModelKernel::BuildGetModelRsp(const std::shared_ptr<FBBuilder> &fbb, con
}
auto fbs_feature_maps_vector = fbb->CreateVector(fbs_feature_maps);
// construct compress feature maps with fbs
std::vector<flatbuffers::Offset<schema::CompressFeatureMap>> fbs_compress_feature_maps;
for (const auto &compress_feature_map : compress_feature_maps) {
if (compress_feature_map.first.find(kMinVal) != string::npos ||
compress_feature_map.first.find(kMaxVal) != string::npos) {
continue;
}
auto fbs_compress_weight_fullname = fbb->CreateString(compress_feature_map.first);
auto fbs_compress_weight_data = fbb->CreateVector(reinterpret_cast<int8_t *>(compress_feature_map.second->addr),
compress_feature_map.second->size / sizeof(int8_t));
const std::string min_val_name = compress_feature_map.first + "." + kMinVal;
const std::string max_val_name = compress_feature_map.first + "." + kMaxVal;
const AddressPtr min_val_ptr = compress_feature_maps.at(min_val_name);
const AddressPtr max_val_ptr = compress_feature_maps.at(max_val_name);
float *fbs_min_val_ptr = reinterpret_cast<float *>(min_val_ptr->addr);
float *fbs_max_val_ptr = reinterpret_cast<float *>(max_val_ptr->addr);
auto fbs_compress_feature_map = schema::CreateCompressFeatureMap(
*(fbb.get()), fbs_compress_weight_fullname, fbs_compress_weight_data, *fbs_min_val_ptr, *fbs_max_val_ptr);
fbs_compress_feature_maps.push_back(fbs_compress_feature_map);
}
auto fbs_compress_feature_maps_vector = fbb->CreateVector(fbs_compress_feature_maps);
schema::ResponseGetModelBuilder rsp_get_model_builder(*(fbb.get()));
rsp_get_model_builder.add_retcode(static_cast<int>(retcode));
rsp_get_model_builder.add_reason(fbs_reason);
rsp_get_model_builder.add_iteration(static_cast<int>(iter));
rsp_get_model_builder.add_feature_map(fbs_feature_maps_vector);
rsp_get_model_builder.add_timestamp(fbs_timestamp);
rsp_get_model_builder.add_download_compress_type(compressType);
rsp_get_model_builder.add_compress_feature_map(fbs_compress_feature_maps_vector);
auto rsp_get_model = rsp_get_model_builder.Finish();
fbb->Finish(rsp_get_model);
return;

View File

@ -25,6 +25,7 @@
#include "fl/server/executor.h"
#include "fl/server/kernel/round/round_kernel.h"
#include "fl/server/kernel/round/round_kernel_factory.h"
#include "fl/compression/encode_executor.h"
namespace mindspore {
namespace fl {
@ -44,7 +45,9 @@ class GetModelKernel : public RoundKernel {
void GetModel(const schema::RequestGetModel *get_model_req, const std::shared_ptr<ps::core::MessageHandler> &message);
void BuildGetModelRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode,
const std::string &reason, const size_t iter,
const std::map<std::string, AddressPtr> &feature_maps, const std::string &timestamp);
const std::map<std::string, AddressPtr> &feature_maps, const std::string &timestamp,
const schema::CompressType &compressType = schema::CompressType_NO_COMPRESS,
const std::map<std::string, AddressPtr> &compress_feature_maps = {});
// The executor is for getting model for getModel request.
Executor *executor_;

View File

@ -126,10 +126,19 @@ bool StartFLJobKernel::Launch(const uint8_t *req_data, size_t len,
IncreaseAcceptClientNum();
auto curr_iter_num = LocalMetaStore::GetInstance().curr_iter_num();
auto last_iteration = curr_iter_num - 1;
auto cache = ModelStore::GetInstance().GetModelResponseCache(name_, curr_iter_num, last_iteration);
auto download_compress_types = start_fl_job_req->download_compress_types();
schema::CompressType compressType =
mindspore::fl::compression::CompressExecutor::GetInstance().GetCompressType(download_compress_types);
std::string compress_type;
if (compressType == schema::CompressType_QUANT) {
compress_type = kQuant;
} else {
compress_type = kNoCompress;
}
auto cache = ModelStore::GetInstance().GetModelResponseCache(name_, curr_iter_num, last_iteration, compress_type);
if (cache == nullptr) {
StartFLJob(fbb);
cache = ModelStore::GetInstance().StoreModelResponseCache(name_, curr_iter_num, last_iteration,
StartFLJob(fbb, device_meta, start_fl_job_req);
cache = ModelStore::GetInstance().StoreModelResponseCache(name_, curr_iter_num, last_iteration, compress_type,
fbb->GetBufferPointer(), fbb->GetSize());
if (cache == nullptr) {
SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize());
@ -303,22 +312,40 @@ ResultCode StartFLJobKernel::CountForStartFLJob(const std::shared_ptr<FBBuilder>
return ResultCode::kSuccess;
}
void StartFLJobKernel::StartFLJob(const std::shared_ptr<FBBuilder> &fbb) {
void StartFLJobKernel::StartFLJob(const std::shared_ptr<FBBuilder> &fbb, const DeviceMeta &,
const schema::RequestFLJob *start_fl_job_req) {
size_t last_iteration = LocalMetaStore::GetInstance().curr_iter_num() - 1;
auto feature_maps = ModelStore::GetInstance().GetModelByIterNum(last_iteration);
if (feature_maps.empty()) {
MS_LOG(WARNING) << "The feature map for startFLJob is empty.";
std::map<std::string, AddressPtr> feature_maps = {};
std::map<std::string, AddressPtr> compress_feature_maps = {};
// Only download compress weights if client support.
auto download_compress_types = start_fl_job_req->download_compress_types();
schema::CompressType compressType =
mindspore::fl::compression::CompressExecutor::GetInstance().GetCompressType(download_compress_types);
if (compressType == schema::CompressType_NO_COMPRESS) {
feature_maps = ModelStore::GetInstance().GetModelByIterNum(last_iteration);
if (feature_maps.empty()) {
MS_LOG(WARNING) << "The feature map for startFLJob is empty.";
}
} else {
if (mindspore::fl::compression::CompressExecutor::GetInstance().EnableCompressWeight(compressType)) {
compress_feature_maps = ModelStore::GetInstance().GetCompressModelByIterNum(last_iteration, compressType);
}
}
BuildStartFLJobRsp(fbb, schema::ResponseCode_SUCCEED, "success", true,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)),
feature_maps);
feature_maps, compressType, compress_feature_maps);
return;
}
void StartFLJobKernel::BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode,
const std::string &reason, const bool is_selected,
const std::string &next_req_time,
std::map<std::string, AddressPtr> feature_maps) {
const std::map<std::string, AddressPtr> &feature_maps,
const schema::CompressType &compressType,
const std::map<std::string, AddressPtr> &compress_feature_maps) {
if (fbb == nullptr) {
MS_LOG(WARNING) << "Input fbb is nullptr.";
return;
@ -350,6 +377,12 @@ void StartFLJobKernel::BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb,
auto cipher_public_params =
schema::CreateCipherPublicParams(*fbb.get(), encrypt_type, pw_params, dp_params, ds_params);
#endif
schema::CompressType upload_compress_type;
if (ps::PSContext::instance()->upload_compress_type() == kDiffSparseQuant) {
upload_compress_type = schema::CompressType_DIFF_SPARSE_QUANT;
} else {
upload_compress_type = schema::CompressType_NO_COMPRESS;
}
schema::FLPlanBuilder fl_plan_builder(*(fbb.get()));
fl_plan_builder.add_fl_name(fbs_fl_name);
@ -375,6 +408,33 @@ void StartFLJobKernel::BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb,
}
auto fbs_feature_maps_vector = fbb->CreateVector(fbs_feature_maps);
// construct compress feature maps with fbs
std::vector<flatbuffers::Offset<schema::CompressFeatureMap>> fbs_compress_feature_maps;
for (const auto &compress_feature_map : compress_feature_maps) {
if (compressType == schema::CompressType_QUANT) {
if (compress_feature_map.first.find(kMinVal) != string::npos ||
compress_feature_map.first.find(kMaxVal) != string::npos) {
continue;
}
auto fbs_compress_weight_fullname = fbb->CreateString(compress_feature_map.first);
auto fbs_compress_weight_data = fbb->CreateVector(reinterpret_cast<int8_t *>(compress_feature_map.second->addr),
compress_feature_map.second->size / sizeof(int8_t));
const std::string min_val_name = compress_feature_map.first + "." + kMinVal;
const std::string max_val_name = compress_feature_map.first + "." + kMaxVal;
const AddressPtr min_val_ptr = compress_feature_maps.at(min_val_name);
const AddressPtr max_val_ptr = compress_feature_maps.at(max_val_name);
float *fbs_min_val_ptr = reinterpret_cast<float *>(min_val_ptr->addr);
float *fbs_max_val_ptr = reinterpret_cast<float *>(max_val_ptr->addr);
auto fbs_compress_feature_map = schema::CreateCompressFeatureMap(
*(fbb.get()), fbs_compress_weight_fullname, fbs_compress_weight_data, *fbs_min_val_ptr, *fbs_max_val_ptr);
fbs_compress_feature_maps.push_back(fbs_compress_feature_map);
}
}
auto fbs_compress_feature_maps_vector = fbb->CreateVector(fbs_compress_feature_maps);
schema::ResponseFLJobBuilder rsp_fl_job_builder(*(fbb.get()));
rsp_fl_job_builder.add_retcode(static_cast<int>(retcode));
rsp_fl_job_builder.add_reason(fbs_reason);
@ -383,6 +443,10 @@ void StartFLJobKernel::BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb,
rsp_fl_job_builder.add_next_req_time(fbs_next_req_time);
rsp_fl_job_builder.add_fl_plan_config(fbs_fl_plan);
rsp_fl_job_builder.add_feature_map(fbs_feature_maps_vector);
rsp_fl_job_builder.add_download_compress_type(compressType);
rsp_fl_job_builder.add_compress_feature_map(fbs_compress_feature_maps_vector);
rsp_fl_job_builder.add_upload_compress_type(upload_compress_type);
rsp_fl_job_builder.add_upload_sparse_rate(ps::PSContext::instance()->upload_sparse_rate());
auto rsp_fl_job = rsp_fl_job_builder.Finish();
fbb->Finish(rsp_fl_job);
return;

View File

@ -25,6 +25,9 @@
#include "fl/server/executor.h"
#include "fl/server/kernel/round/round_kernel.h"
#include "fl/server/kernel/round/round_kernel_factory.h"
#include "schema/fl_job_generated.h"
#include "schema/cipher_generated.h"
#include "fl/compression/encode_executor.h"
namespace mindspore {
namespace fl {
@ -56,7 +59,8 @@ class StartFLJobKernel : public RoundKernel {
// Distributed count service counts for startFLJob.
ResultCode CountForStartFLJob(const std::shared_ptr<FBBuilder> &fbb, const schema::RequestFLJob *start_fl_job_req);
void StartFLJob(const std::shared_ptr<FBBuilder> &fbb);
void StartFLJob(const std::shared_ptr<FBBuilder> &fbb, const DeviceMeta &device_meta,
const schema::RequestFLJob *start_fl_job_req);
bool JudgeFLJobCert(const std::shared_ptr<FBBuilder> &fbb, const schema::RequestFLJob *start_fl_job_req);
@ -65,7 +69,9 @@ class StartFLJobKernel : public RoundKernel {
// Build response for startFLJob round no matter success or failure.
void BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode,
const std::string &reason, const bool is_selected, const std::string &next_req_time,
std::map<std::string, AddressPtr> feature_maps = {});
const std::map<std::string, AddressPtr> &feature_maps = {},
const schema::CompressType &compressType = schema::CompressType_NO_COMPRESS,
const std::map<std::string, AddressPtr> &compress_feature_maps = {});
// The executor is for getting the initial model for startFLJob request.
Executor *executor_;

View File

@ -201,23 +201,27 @@ ResultCode UpdateModelKernel::VerifyUpdateModel(const schema::RequestUpdateModel
}
std::unordered_map<std::string, size_t> feature_map;
auto upload_feature_map = update_model_req->feature_map();
MS_ERROR_IF_NULL_W_RET_VAL(upload_feature_map, ResultCode::kFail);
for (uint32_t i = 0; i < upload_feature_map->size(); i++) {
const auto &item = upload_feature_map->Get(i);
MS_ERROR_IF_NULL_W_RET_VAL(item, ResultCode::kFail);
MS_ERROR_IF_NULL_W_RET_VAL(item->weight_fullname(), ResultCode::kFail);
MS_ERROR_IF_NULL_W_RET_VAL(item->data(), ResultCode::kFail);
if (ps::PSContext::instance()->upload_compress_type() != kDiffSparseQuant) {
auto upload_feature_map = update_model_req->feature_map();
MS_ERROR_IF_NULL_W_RET_VAL(upload_feature_map, ResultCode::kFail);
for (uint32_t i = 0; i < upload_feature_map->size(); i++) {
const auto &item = upload_feature_map->Get(i);
MS_ERROR_IF_NULL_W_RET_VAL(item, ResultCode::kFail);
MS_ERROR_IF_NULL_W_RET_VAL(item->weight_fullname(), ResultCode::kFail);
MS_ERROR_IF_NULL_W_RET_VAL(item->data(), ResultCode::kFail);
std::string weight_full_name = item->weight_fullname()->str();
size_t weight_size = item->data()->size() * sizeof(float);
feature_map[weight_full_name] = weight_size;
std::string weight_full_name = item->weight_fullname()->str();
size_t weight_size = item->data()->size() * sizeof(float);
feature_map[weight_full_name] = weight_size;
}
}
bool verifyFeatureMapIsSuccess;
if (ps::PSContext::instance()->encrypt_type() == ps::kDSEncryptType && update_model_req->sign() != 0) {
MS_ERROR_IF_NULL_W_RET_VAL(update_model_req->index_array(), ResultCode::kFail);
verifyFeatureMapIsSuccess = VerifySignDSFeatureMap(feature_map, update_model_req);
} else if (ps::PSContext::instance()->upload_compress_type() == kDiffSparseQuant) {
verifyFeatureMapIsSuccess = VerifyUploadCompressFeatureMap(update_model_req);
} else {
verifyFeatureMapIsSuccess = LocalMetaStore::GetInstance().verifyAggregationFeatureMap(feature_map);
}
@ -280,6 +284,45 @@ bool UpdateModelKernel::VerifySignDSFeatureMap(const std::unordered_map<std::str
return true;
}
bool UpdateModelKernel::VerifyUploadCompressFeatureMap(const schema::RequestUpdateModel *update_model_req) {
  // Validates a compressed upload request against the server-side aggregation model:
  // the sparse rate must match the context setting, the uploaded name list must be
  // non-empty, no larger than the model, and contain only known weights, and the
  // compressed feature map itself must be present and non-empty.
  auto &server_feature_map = LocalMetaStore::GetInstance().aggregation_feature_map();
  auto req_sparse_rate = update_model_req->upload_sparse_rate();
  // NOTE(review): exact float comparison — assumes client echoes the server value bit-for-bit.
  if (req_sparse_rate != ps::PSContext::instance()->upload_sparse_rate()) {
    MS_LOG(WARNING) << "The upload_sparse_rate must be equal to the setting in context.";
    return false;
  }
  auto uploaded_names = update_model_req->name_vec();
  if (uploaded_names == nullptr) {
    MS_LOG(WARNING) << "The name_vec is null.";
    return false;
  }
  if (uploaded_names->size() == 0) {
    MS_LOG(WARNING) << "The size of name_vec must be larger than 0.";
    return false;
  }
  if (uploaded_names->size() > server_feature_map.size()) {
    MS_LOG(WARNING) << "The size of name_vec must be smaller than model in server.";
    return false;
  }
  // Every uploaded weight name must already exist in the server model.
  for (size_t idx = 0; idx < uploaded_names->size(); ++idx) {
    const std::string weight_name = uploaded_names->Get(idx)->str();
    if (server_feature_map.count(weight_name) == 0) {
      MS_LOG(WARNING) << "The upload name: " << weight_name << " is not in model in server.";
      return false;
    }
  }
  auto compressed_maps = update_model_req->compress_feature_map();
  if (compressed_maps == nullptr) {
    MS_LOG(WARNING) << "The upload compress feature map is null.";
    return false;
  }
  if (compressed_maps->size() == 0) {
    MS_LOG(WARNING) << "The upload compress feature map is empty.";
    return false;
  }
  return true;
}
ResultCode UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *update_model_req,
const std::shared_ptr<FBBuilder> &fbb, const DeviceMeta &device_meta) {
MS_ERROR_IF_NULL_W_RET_VAL(update_model_req, ResultCode::kFail);
@ -292,6 +335,8 @@ ResultCode UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *upda
std::map<std::string, UploadData> feature_map;
if (ps::PSContext::instance()->encrypt_type() == ps::kDSEncryptType) {
feature_map = ParseSignDSFeatureMap(update_model_req, data_size, &weight_map);
} else if (ps::PSContext::instance()->upload_compress_type() == kDiffSparseQuant) {
feature_map = ParseUploadCompressFeatureMap(update_model_req, data_size, &weight_map);
} else {
feature_map = ParseFeatureMap(update_model_req);
}
@ -397,6 +442,89 @@ std::map<std::string, UploadData> UpdateModelKernel::ParseSignDSFeatureMap(
return feature_map;
}
// Parses the uploaded feature map of a (possibly compressed) update-model request.
// If the request carries a supported compress type, the payload is decoded via
// DecodeFeatureMap; otherwise the plain (uncompressed) feature_map is wrapped
// directly into UploadData entries.
//
// @param update_model_req  Flatbuffer request; must not be null.
// @param data_size         Client data size, forwarded to the decoder.
// @param weight_map        Out-param owning the decoded float buffers; the returned
//                          UploadData entries point into it (decode path only).
// @return name -> UploadData map; empty on null request.
std::map<std::string, UploadData> UpdateModelKernel::ParseUploadCompressFeatureMap(
  const schema::RequestUpdateModel *update_model_req, size_t data_size,
  std::map<std::string, std::vector<float>> *weight_map) {
  MS_ERROR_IF_NULL_W_RET_VAL(update_model_req, {});
  std::map<std::string, UploadData> feature_map;
  // The request's declared compress type is normalized by the decode executor
  // (e.g. downgraded to NO_COMPRESS when unsupported).
  schema::CompressType upload_compress_type = update_model_req->upload_compress_type();
  upload_compress_type =
    mindspore::fl::compression::DecodeExecutor::GetInstance().GetCompressType(upload_compress_type);
  MS_LOG(INFO) << "This schema upload compress type is: " << upload_compress_type;
  if (upload_compress_type != schema::CompressType_NO_COMPRESS) {
    MS_LOG(INFO) << "This upload compress type is DIFF_SPARSE_QUANT.";
    feature_map = DecodeFeatureMap(weight_map, update_model_req, upload_compress_type, data_size);
    return feature_map;
  }
  MS_LOG(INFO) << "This upload compress type is NO_COMPRESS.";
  // Some clients upload origin weights.
  auto fbs_feature_map = update_model_req->feature_map();
  MS_ERROR_IF_NULL_W_RET_VAL(fbs_feature_map, feature_map);
  for (uint32_t i = 0; i < fbs_feature_map->size(); i++) {
    std::string weight_full_name = fbs_feature_map->Get(i)->weight_fullname()->str();
    // const_cast: UploadData stores a mutable addr, but the bytes live inside the
    // flatbuffer request — valid only for the lifetime of update_model_req.
    float *weight_data = const_cast<float *>(fbs_feature_map->Get(i)->data()->data());
    size_t weight_size = fbs_feature_map->Get(i)->data()->size() * sizeof(float);
    UploadData upload_data;
    upload_data[kNewWeight].addr = weight_data;
    upload_data[kNewWeight].size = weight_size;
    feature_map[weight_full_name] = upload_data;
  }
  return feature_map;
}
// Decodes a compressed upload (DIFF_SPARSE_QUANT family) into plain float weights.
// The compressed int8 payloads plus per-weight min/max are handed to DecodeExecutor,
// which reconstructs float vectors into *weight_map; the returned UploadData entries
// point into those vectors, so *weight_map must outlive the returned map.
//
// @param weight_map           Out-param that owns the reconstructed float data.
// @param update_model_req     Flatbuffer request carrying name_vec / compress_feature_map.
// @param upload_compress_type Normalized compress type for the decoder.
// @param data_size            Client data size forwarded to the decoder.
// @return name -> UploadData map; empty map when decoding fails.
std::map<std::string, UploadData> UpdateModelKernel::DecodeFeatureMap(
  std::map<std::string, std::vector<float>> *weight_map, const schema::RequestUpdateModel *update_model_req,
  schema::CompressType upload_compress_type, size_t data_size) {
  std::map<std::string, UploadData> feature_map;
  // Get and set decode hyper parameters. The iteration number doubles as the
  // pseudo-random seed shared with the client-side mask construction.
  auto seed = update_model_req->iteration();
  MS_LOG(INFO) << "The seed for compression is: " << seed;
  auto upload_sparse_rate = update_model_req->upload_sparse_rate();
  MS_LOG(INFO) << "The upload_sparse_rate for compression is: " << upload_sparse_rate;
  // Get name vector.
  auto fbs_name_vec = update_model_req->name_vec();
  std::vector<std::string> name_vec;
  for (size_t i = 0; i < fbs_name_vec->size(); ++i) {
    name_vec.emplace_back(fbs_name_vec->Get(i)->str());
  }
  // Parameter process for decode: copy each int8 payload plus its quantization
  // range out of the flatbuffer into CompressFeatureMap structs.
  auto fbs_compress_feature_map = update_model_req->compress_feature_map();
  std::vector<mindspore::fl::compression::CompressFeatureMap> compress_feature_maps;
  for (size_t i = 0; i < fbs_compress_feature_map->size(); ++i) {
    mindspore::fl::compression::CompressFeatureMap compress_feature_map;
    int8_t *compress_weight_data = const_cast<int8_t *>(fbs_compress_feature_map->Get(i)->compress_data()->data());
    size_t compress_weight_size = fbs_compress_feature_map->Get(i)->compress_data()->size();
    MS_LOG(INFO) << "The compress weight size: " << compress_weight_size;
    // NOTE(review): element-by-element copy plus per-weight INFO logging is heavy
    // for large models — consider batching if this shows up in profiles.
    for (size_t j = 0; j < compress_weight_size; ++j) {
      compress_feature_map.compress_data.emplace_back(compress_weight_data[j]);
    }
    compress_feature_map.min_val = fbs_compress_feature_map->Get(i)->min_val();
    compress_feature_map.max_val = fbs_compress_feature_map->Get(i)->max_val();
    MS_LOG(INFO) << "Min value: " << compress_feature_map.min_val;
    MS_LOG(INFO) << "Max value: " << compress_feature_map.max_val;
    compress_feature_maps.emplace_back(compress_feature_map);
  }
  // Decode.
  bool status = mindspore::fl::compression::DecodeExecutor::GetInstance().Decode(
    weight_map, compress_feature_maps, upload_compress_type, upload_sparse_rate, seed, name_vec, data_size);
  if (status) {
    for (size_t i = 0; i < name_vec.size(); ++i) {
      std::string weight_full_name = name_vec[i];
      size_t weight_size = (*weight_map)[weight_full_name].size() * sizeof(float);
      UploadData upload_data;
      upload_data[kNewWeight].addr = (*weight_map)[weight_full_name].data();
      upload_data[kNewWeight].size = weight_size;
      feature_map[weight_full_name] = upload_data;
    }
    return feature_map;
  }
  MS_LOG(WARNING) << "Decode failed!";
  return feature_map;
}
ResultCode UpdateModelKernel::CountForAggregation(const std::string &req_fl_id) {
std::string count_reason = "";
if (!DistributedCountService::GetInstance().Count(kCountForAggregation, req_fl_id, &count_reason)) {

View File

@ -30,6 +30,9 @@
#ifdef ENABLE_ARMOUR
#include "fl/armour/cipher/cipher_meta_storage.h"
#endif
#include "fl/compression/decode_executor.h"
#include "schema/fl_job_generated.h"
#include "schema/cipher_generated.h"
namespace mindspore {
namespace fl {
@ -64,8 +67,12 @@ class UpdateModelKernel : public RoundKernel {
std::map<std::string, UploadData> ParseSignDSFeatureMap(const schema::RequestUpdateModel *update_model_req,
size_t data_size,
std::map<std::string, std::vector<float>> *weight_map);
std::map<std::string, UploadData> ParseUploadCompressFeatureMap(
const schema::RequestUpdateModel *update_model_req, size_t data_size,
std::map<std::string, std::vector<float>> *weight_map);
bool VerifySignDSFeatureMap(const std::unordered_map<std::string, size_t> &model,
const schema::RequestUpdateModel *update_model_req);
bool VerifyUploadCompressFeatureMap(const schema::RequestUpdateModel *update_model_req);
ResultCode CountForUpdateModel(const std::shared_ptr<FBBuilder> &fbb,
const schema::RequestUpdateModel *update_model_req);
sigVerifyResult VerifySignature(const schema::RequestUpdateModel *update_model_req);
@ -78,6 +85,11 @@ class UpdateModelKernel : public RoundKernel {
// The time window of one iteration.
size_t iteration_time_window_{0};
// Decode functions of compression.
std::map<std::string, UploadData> DecodeFeatureMap(std::map<std::string, std::vector<float>> *weight_map,
const schema::RequestUpdateModel *update_model_req,
schema::CompressType upload_compress_type, size_t data_size);
};
} // namespace kernel
} // namespace server

View File

@ -44,6 +44,11 @@ void MemoryRegister::StoreCharArray(std::unique_ptr<char[]> *array) {
MS_ERROR_IF_NULL_WO_RET_VAL(array);
char_arrays_.push_back(std::move(*array));
}
// Takes ownership of a single heap-allocated float (used by RegisterParameter for
// scalar parameters such as quantization min/max values) so its lifetime matches
// the register. The caller's unique_ptr is left empty.
void MemoryRegister::StoreFloat32(std::unique_ptr<float> *param) {
  MS_ERROR_IF_NULL_WO_RET_VAL(param);
  float_params_.push_back(std::move(*param));
}
} // namespace server
} // namespace fl
} // namespace mindspore

View File

@ -24,6 +24,7 @@
#include <utility>
#include <typeinfo>
#include "fl/server/common.h"
#include "fl/compression/encode_executor.h"
namespace mindspore {
namespace fl {
@ -70,6 +71,25 @@ class MemoryRegister {
return;
}
// Registers a single scalar parameter: records its address under `name` and takes
// ownership of the allocation so it lives as long as this MemoryRegister.
// Only T == float is currently supported; other types log an error and register nothing.
//
// @param name  Key under which the address is registered.
// @param param Owning pointer to the scalar; ownership is transferred on success.
// @param size  Stored as the Address size field.
//              NOTE(review): callers pass 1 here for a float, while RegisterArray
//              sizes are in bytes — confirm whether this should be sizeof(float).
template <typename T>
void RegisterParameter(const std::string &name, std::unique_ptr<T> *param, size_t size) {
  MS_EXCEPTION_IF_NULL(param);
  // Capture the raw address before ownership moves; the pointee does not move,
  // so the address stays valid after StoreFloat32 stores the unique_ptr.
  void *data = param->get();
  AddressPtr addressPtr = std::make_shared<Address>();
  addressPtr->addr = data;
  addressPtr->size = size;
  // Runtime type dispatch; a compile-time check (std::is_same) would also work here.
  if (typeid(T) == typeid(float)) {
    auto float_param = CastUniqueParamPtr<float, T>(param);
    StoreFloat32(&float_param);
  } else {
    MS_LOG(ERROR) << "MemoryRegister does not support type " << typeid(T).name();
    return;
  }
  RegisterAddressPtr(name, addressPtr);
  return;
}
private:
std::map<std::string, AddressPtr> addresses_;
std::vector<std::unique_ptr<float[]>> float_arrays_;
@ -86,6 +106,15 @@ class MemoryRegister {
std::unique_ptr<T[]> CastUniquePtr(std::unique_ptr<S[]> *array) {
return std::unique_ptr<T[]>{reinterpret_cast<T *>(array->release())};
}
std::vector<std::unique_ptr<float>> float_params_;
void StoreFloat32(std::unique_ptr<float> *array);
// Transfers ownership from a unique_ptr<S> to a unique_ptr<T> via reinterpret_cast.
// Only safe when T and S are layout-compatible; in this file it is used with
// T == S == float from RegisterParameter's typeid-checked branch.
template <typename T, typename S>
std::unique_ptr<T> CastUniqueParamPtr(std::unique_ptr<S> *param) {
  return std::unique_ptr<T>{reinterpret_cast<T *>(param->release())};
}
};
} // namespace server
} // namespace fl

View File

@ -19,6 +19,7 @@
#include <string>
#include <memory>
#include "fl/server/executor.h"
#include "pipeline/jit/parse/parse.h"
#include "include/common/utils/python_adapter.h"
namespace mindspore {
@ -33,6 +34,10 @@ void ModelStore::Initialize(uint32_t rank_id, uint32_t max_count) {
max_model_count_ = max_count;
initial_model_ = AssignNewModelMemory();
iteration_to_model_[kInitIterationNum] = initial_model_;
std::map<std::string, AddressPtr> model = Executor::GetInstance().GetModel();
for (const auto &item : mindspore::fl::compression::kCompressTypeMap) {
iteration_to_compress_model_[kInitIterationNum][item.first] = AssignNewCompressModelMemory(item.first, model);
}
model_size_ = ComputeModelSize();
MS_LOG(INFO) << "Model store checkpoint dir is: " << ps::PSContext::instance()->checkpoint_dir();
}
@ -101,6 +106,24 @@ std::map<std::string, AddressPtr> ModelStore::GetModelByIterNum(size_t iteration
return model;
}
// Returns the compressed model (name -> AddressPtr) stored for the given iteration
// and compress type, or an empty map (with an ERROR log) when either the iteration
// or the compress type has no stored entry.
//
// Thread-safe: holds model_mtx_ for the duration of the lookup. Note the returned
// AddressPtrs point into the stored MemoryRegister, which stays alive via the
// shared_ptr held in iteration_to_compress_model_.
std::map<std::string, AddressPtr> ModelStore::GetCompressModelByIterNum(size_t iteration,
                                                                        schema::CompressType compressType) {
  std::unique_lock<std::mutex> lock(model_mtx_);
  std::map<std::string, AddressPtr> compressModel = {};
  // Use find() so each level is looked up once, without copying the inner
  // (compress type -> MemoryRegister) map as the previous count()+operator[]
  // implementation did.
  auto iter_entry = iteration_to_compress_model_.find(iteration);
  if (iter_entry == iteration_to_compress_model_.end()) {
    MS_LOG(ERROR) << "Compress Model for iteration " << iteration << " is not stored.";
    return compressModel;
  }
  const auto &compress_model_map = iter_entry->second;
  auto type_entry = compress_model_map.find(compressType);
  if (type_entry == compress_model_map.end()) {
    MS_LOG(ERROR) << "Compress Model for compress type " << compressType << " is not stored.";
    return compressModel;
  }
  compressModel = type_entry->second->addresses();
  return compressModel;
}
void ModelStore::Reset() {
std::unique_lock<std::mutex> lock(model_mtx_);
initial_model_ = iteration_to_model_.rbegin()->second;
@ -114,6 +137,11 @@ const std::map<size_t, std::shared_ptr<MemoryRegister>> &ModelStore::iteration_t
return iteration_to_model_;
}
// Returns a const reference to the full iteration -> (compress type -> model) store.
// NOTE(review): the lock is released when this function returns, so the caller
// reads the referenced map unsynchronized — confirm callers only use it while no
// writer can run (e.g. during iteration barriers).
const std::map<size_t, CompressTypeMap> &ModelStore::iteration_to_compress_model() {
  std::unique_lock<std::mutex> lock(model_mtx_);
  return iteration_to_compress_model_;
}
size_t ModelStore::model_size() const { return model_size_; }
std::shared_ptr<MemoryRegister> ModelStore::AssignNewModelMemory() {
@ -146,6 +174,86 @@ std::shared_ptr<MemoryRegister> ModelStore::AssignNewModelMemory() {
return memory_register;
}
// Compresses the given model with the requested compress type and copies the result
// into a freshly allocated MemoryRegister. For QUANT compression each weight yields
// three registered entries: the int8 payload plus "<name>.min_val" / "<name>.max_val"
// scalar parameters used for dequantization.
//
// @param compressType Target compress type (only CompressType_QUANT registers data).
// @param model        Plain model (name -> AddressPtr) to compress; must be non-empty.
// @return The populated register, or nullptr on encode/memcpy failure.
//         (MS_LOG(EXCEPTION) throws, so the first nullptr return is unreachable.)
std::shared_ptr<MemoryRegister> ModelStore::AssignNewCompressModelMemory(
  schema::CompressType compressType, const std::map<std::string, AddressPtr> &model) {
  if (model.empty()) {
    MS_LOG(EXCEPTION) << "Model feature map is empty.";
    return nullptr;
  }
  // Copy every weight out of the AddressPtrs into float vectors for the encoder.
  std::map<string, std::vector<float>> feature_maps;
  for (auto &feature_map : model) {
    auto weight_fullname = feature_map.first;
    auto weight_data = reinterpret_cast<float *>(feature_map.second->addr);
    std::vector<float> weight_data_vector{weight_data, weight_data + feature_map.second->size / sizeof(float)};
    feature_maps[weight_fullname] = weight_data_vector;
  }
  std::map<std::string, mindspore::fl::compression::CompressWeight> compressWeights;
  bool status = mindspore::fl::compression::CompressExecutor::GetInstance().construct_compress_weight(
    &compressWeights, feature_maps, compressType);
  if (!status) {
    MS_LOG(ERROR) << "Encode failed!";
    return nullptr;
  }
  // Assign new memory for the compress model.
  std::shared_ptr<MemoryRegister> memory_register = std::make_shared<MemoryRegister>();
  MS_ERROR_IF_NULL_W_RET_VAL(memory_register, nullptr);
  MS_LOG(INFO) << "Register compressWeight for compressType: " << schema::EnumNameCompressType(compressType);
  for (const auto &compressWeight : compressWeights) {
    if (compressType == schema::CompressType_QUANT) {
      std::string compress_weight_name = compressWeight.first;
      std::string min_val_name = compress_weight_name + "." + kMinVal;
      std::string max_val_name = compress_weight_name + "." + kMaxVal;
      // Deep-copy the encoded int8 payload into register-owned memory.
      size_t compress_weight_size = compressWeight.second.compress_data_len * sizeof(int8_t);
      auto compress_weight_data = std::make_unique<char[]>(compress_weight_size);
      auto src_data_size = compress_weight_size;
      auto dst_data_size = compress_weight_size;
      int ret =
        memcpy_s(compress_weight_data.get(), dst_data_size, compressWeight.second.compress_data.data(), src_data_size);
      if (ret != 0) {
        MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
        return nullptr;
      }
      memory_register->RegisterArray(compress_weight_name, &compress_weight_data, compress_weight_size);
      // Register the quantization range as scalar parameters.
      // NOTE(review): float_size is 1 while array sizes above are bytes — see
      // RegisterParameter; confirm whether sizeof(float) is intended.
      size_t float_size = 1;
      auto min_val_ptr = std::make_unique<float>(compressWeight.second.min_val);
      auto max_val_ptr = std::make_unique<float>(compressWeight.second.max_val);
      memory_register->RegisterParameter(min_val_name, &min_val_ptr, float_size);
      memory_register->RegisterParameter(max_val_name, &max_val_ptr, float_size);
    }
  }
  return memory_register;
}
// Stores the compressed forms of `new_model` for the given iteration, one entry per
// supported compress type. A no-op (with a WARNING/ERROR log) when the iteration is
// already stored or the model is empty. When the store would exceed max_model_count_,
// the oldest iteration's compressed models are evicted first.
//
// Thread-safe: holds model_mtx_ for the duration.
void ModelStore::StoreCompressModelByIterNum(size_t iteration, const std::map<std::string, AddressPtr> &new_model) {
  std::unique_lock<std::mutex> lock(model_mtx_);
  if (iteration_to_compress_model_.count(iteration) != 0) {
    MS_LOG(WARNING) << "Compress Model for iteration " << iteration << " is already stored";
    return;
  }
  if (new_model.empty()) {
    MS_LOG(ERROR) << "Compress Model feature map is empty.";
    return;
  }
  iteration_to_compress_model_[iteration] = {};
  if (iteration_to_compress_model_.size() >= max_model_count_) {
    // Evict the oldest iteration. (The previous implementation also copied the
    // evicted map into a local and cleared the copy — a costly no-op, removed.)
    (void)iteration_to_compress_model_.erase(iteration_to_compress_model_.begin());
  }
  for (const auto &item : mindspore::fl::compression::kCompressTypeMap) {
    auto memory_register = AssignNewCompressModelMemory(item.first, new_model);
    MS_ERROR_IF_NULL_WO_RET_VAL(memory_register);
    iteration_to_compress_model_[iteration][item.first] = memory_register;
  }
  return;
}
size_t ModelStore::ComputeModelSize() {
std::unique_lock<std::mutex> lock(model_mtx_);
if (iteration_to_model_.empty()) {
@ -179,13 +287,15 @@ void ModelStore::RelModelResponseCache(const void *data, size_t datalen, void *e
std::shared_ptr<std::vector<uint8_t>> ModelStore::GetModelResponseCache(const string &round_name,
size_t cur_iteration_num,
size_t model_iteration_num) {
size_t model_iteration_num,
const std::string &compress_type) {
std::unique_lock<std::mutex> lock(model_response_cache_lock_);
auto it = std::find_if(model_response_cache_.begin(), model_response_cache_.end(),
[&round_name, cur_iteration_num, model_iteration_num](const HttpResponseModelCache &item) {
return item.round_name == round_name && item.cur_iteration_num == cur_iteration_num &&
item.model_iteration_num == model_iteration_num;
});
auto it = std::find_if(
model_response_cache_.begin(), model_response_cache_.end(),
[&round_name, cur_iteration_num, model_iteration_num, &compress_type](const HttpResponseModelCache &item) {
return item.round_name == round_name && item.cur_iteration_num == cur_iteration_num &&
item.model_iteration_num == model_iteration_num && item.compress_type == compress_type;
});
if (it == model_response_cache_.end()) {
return nullptr;
}
@ -196,14 +306,16 @@ std::shared_ptr<std::vector<uint8_t>> ModelStore::GetModelResponseCache(const st
std::shared_ptr<std::vector<uint8_t>> ModelStore::StoreModelResponseCache(const string &round_name,
size_t cur_iteration_num,
size_t model_iteration_num, const void *data,
size_t datalen) {
size_t model_iteration_num,
const std::string &compress_type,
const void *data, size_t datalen) {
std::unique_lock<std::mutex> lock(model_response_cache_lock_);
auto it = std::find_if(model_response_cache_.begin(), model_response_cache_.end(),
[&round_name, cur_iteration_num, model_iteration_num](const HttpResponseModelCache &item) {
return item.round_name == round_name && item.cur_iteration_num == cur_iteration_num &&
item.model_iteration_num == model_iteration_num;
});
auto it = std::find_if(
model_response_cache_.begin(), model_response_cache_.end(),
[&round_name, cur_iteration_num, model_iteration_num, &compress_type](const HttpResponseModelCache &item) {
return item.round_name == round_name && item.cur_iteration_num == cur_iteration_num &&
item.model_iteration_num == model_iteration_num && item.compress_type == compress_type;
});
if (it != model_response_cache_.end()) {
it->reference_count += 1;
total_add_reference_count += 1;
@ -223,6 +335,7 @@ std::shared_ptr<std::vector<uint8_t>> ModelStore::StoreModelResponseCache(const
item.round_name = round_name;
item.cur_iteration_num = cur_iteration_num;
item.model_iteration_num = model_iteration_num;
item.compress_type = compress_type;
item.cache = cache;
item.reference_count = 1;
total_add_reference_count += 1;

View File

@ -25,6 +25,7 @@
#include "fl/server/common.h"
#include "fl/server/memory_register.h"
#include "fl/server/executor.h"
#include "fl/compression/encode_executor.h"
#include "fl/server/local_meta_store.h"
namespace mindspore {
@ -36,6 +37,9 @@ constexpr size_t kInitIterationNum = 0;
// The initial iteration number after ModelStore is reset.
constexpr size_t kResetInitialIterNum = 1;
// The compress type map.
using CompressTypeMap = std::map<schema::CompressType, std::shared_ptr<MemoryRegister>>;
// Server framework use ModelStore to store and query models.
// ModelStore stores multiple models because worker could get models of the previous iterations.
class ModelStore {
@ -64,15 +68,25 @@ class ModelStore {
// Returns the model size, which could be calculated at the initializing phase.
size_t model_size() const;
// Get compress model of the given iteration.
std::map<std::string, AddressPtr> GetCompressModelByIterNum(size_t iteration, schema::CompressType compressType);
const std::map<size_t, std::map<schema::CompressType, std::shared_ptr<MemoryRegister>>>
&iteration_to_compress_model();
void StoreCompressModelByIterNum(size_t iteration, const std::map<std::string, AddressPtr> &new_model);
static void RelModelResponseCache(const void *data, size_t datalen, void *extra);
std::shared_ptr<std::vector<uint8_t>> GetModelResponseCache(const std::string &round_name, size_t cur_iteration_num,
size_t model_iteration_num);
size_t model_iteration_num,
const std::string &compress_type);
std::shared_ptr<std::vector<uint8_t>> StoreModelResponseCache(const std::string &round_name, size_t cur_iteration_num,
size_t model_iteration_num, const void *data,
size_t model_iteration_num,
const std::string &compress_type, const void *data,
size_t datalen);
private:
ModelStore() : max_model_count_(0), model_size_(0), iteration_to_model_({}) {}
ModelStore() : max_model_count_(0), model_size_(0), iteration_to_model_({}), iteration_to_compress_model_({}) {}
~ModelStore() = default;
ModelStore(const ModelStore &) = delete;
ModelStore &operator=(const ModelStore &) = delete;
@ -83,6 +97,9 @@ class ModelStore {
// model_size_.
std::shared_ptr<MemoryRegister> AssignNewModelMemory();
std::shared_ptr<MemoryRegister> AssignNewCompressModelMemory(schema::CompressType compressType,
const std::map<std::string, AddressPtr> &model);
// Calculate the model size. This method should be called after iteration_to_model_ is initialized.
size_t ComputeModelSize();
@ -95,12 +112,17 @@ class ModelStore {
// The number of all models stored is max_model_count_.
std::mutex model_mtx_;
std::map<size_t, std::shared_ptr<MemoryRegister>> iteration_to_model_;
// iteration -> (compress type -> compress model)
std::map<size_t, std::map<schema::CompressType, std::shared_ptr<MemoryRegister>>> iteration_to_compress_model_;
uint32_t rank_id_;
struct HttpResponseModelCache {
std::string round_name; // startFlJob, getModel
size_t cur_iteration_num = 0;
size_t model_iteration_num = 0;
std::string compress_type = kNoCompress;
size_t reference_count = 0;
std::shared_ptr<std::vector<uint8_t>> cache = nullptr;
};

View File

@ -507,6 +507,12 @@ PYBIND11_MODULE(_c_expression, m) {
.def("set_global_iteration_time_window", &PSContext::set_global_iteration_time_window,
"Set global iteration time window.")
.def("global_iteration_time_window", &PSContext::global_iteration_time_window, "Get global iteration time window.")
.def("set_upload_compress_type", &PSContext::set_upload_compress_type, "Set upload compress type.")
.def("upload_compress_type", &PSContext::upload_compress_type, "Get upload compress type.")
.def("set_upload_sparse_rate", &PSContext::set_upload_sparse_rate, "Set upload sparse rate.")
.def("upload_sparse_rate", &PSContext::upload_sparse_rate, "Get upload sparse rate.")
.def("set_download_compress_type", &PSContext::set_download_compress_type, "Set download compress type.")
.def("download_compress_type", &PSContext::download_compress_type, "Get download compress type.")
.def("set_checkpoint_dir", &PSContext::set_checkpoint_dir, "Set server checkpoint directory.")
.def("checkpoint_dir", &PSContext::checkpoint_dir, "Server checkpoint directory.");
(void)m.def("_encrypt", &mindspore::pipeline::PyEncrypt, "Encrypt the data.");

View File

@ -550,6 +550,19 @@ void PSContext::set_global_iteration_time_window(const uint64_t &global_iteratio
uint64_t PSContext::global_iteration_time_window() const { return global_iteration_time_window_; }
// Compress type used for client weight uploads (e.g. kNoCompressType or the
// DIFF_SPARSE_QUANT scheme — see upload handling in UpdateModelKernel).
void PSContext::set_upload_compress_type(const std::string &upload_compress_type) {
  upload_compress_type_ = upload_compress_type;
}
std::string PSContext::upload_compress_type() const { return upload_compress_type_; }

// Fraction of parameters retained by the sparse upload compression (0-1).
void PSContext::set_upload_sparse_rate(float upload_sparse_rate) { upload_sparse_rate_ = upload_sparse_rate; }
float PSContext::upload_sparse_rate() const { return upload_sparse_rate_; }

// Compress type used when clients download the model from the server.
void PSContext::set_download_compress_type(const std::string &download_compress_type) {
  download_compress_type_ = download_compress_type;
}
std::string PSContext::download_compress_type() const { return download_compress_type_; }
std::string PSContext::checkpoint_dir() const { return checkpoint_dir_; }
void PSContext::set_checkpoint_dir(const std::string &checkpoint_dir) { checkpoint_dir_ = checkpoint_dir; }

View File

@ -40,6 +40,7 @@ constexpr char kPWEncryptType[] = "PW_ENCRYPT";
constexpr char kStablePWEncryptType[] = "STABLE_PW_ENCRYPT";
constexpr char kNotEncryptType[] = "NOT_ENCRYPT";
constexpr char kDSEncryptType[] = "SIGNDS";
constexpr char kNoCompressType[] = "NO_COMPRESS";
// Use binary data to represent federated learning server's context so that we can judge which round resets the
// iteration. From right to left, each bit stands for:
@ -230,6 +231,15 @@ class PSContext {
void set_global_iteration_time_window(const uint64_t &global_iteration_time_window);
uint64_t global_iteration_time_window() const;
void set_upload_compress_type(const std::string &upload_compress_type);
std::string upload_compress_type() const;
void set_upload_sparse_rate(float upload_sparse_rate);
float upload_sparse_rate() const;
void set_download_compress_type(const std::string &download_compress_type);
std::string download_compress_type() const;
std::string checkpoint_dir() const;
void set_checkpoint_dir(const std::string &checkpoint_dir);
@ -286,6 +296,9 @@ class PSContext {
server_password_(""),
http_url_prefix_(""),
global_iteration_time_window_(3600000),
upload_compress_type_(kNoCompressType),
upload_sparse_rate_(0.4f),
download_compress_type_(kNoCompressType),
checkpoint_dir_("") {}
bool ps_enabled_;
bool is_worker_;
@ -419,6 +432,13 @@ class PSContext {
// The time window of startFLJob round in millisecond.
uint64_t global_iteration_time_window_;
// Hyper parameters for upload compression.
std::string upload_compress_type_;
float upload_sparse_rate_;
// Hyper parameters for download compression.
std::string download_compress_type_;
// directory of server checkpoint
std::string checkpoint_dir_;
};

View File

@ -105,6 +105,16 @@ public class FLLiteClient {
batchSize = flPlan.miniBatch();
String serverMod = flPlan.serverMode();
localFLParameter.setServerMod(serverMod);
// Get and set hyper parameters for compression.
byte uploadCompressType = flJob.uploadCompressType();
LOGGER.info(Common.addTag("[startFLJob] [compression] uploadCompressType: " + uploadCompressType));
localFLParameter.setUploadCompressType(uploadCompressType);
float uploadSparseRate = flJob.uploadSparseRate();
LOGGER.info(Common.addTag("[startFLJob] [compression] uploadSparseRate: " + uploadSparseRate));
localFLParameter.setUploadSparseRatio(uploadSparseRate);
int seed = flJob.iteration();
LOGGER.info(Common.addTag("[startFLJob] [compression] seed: " + seed));
localFLParameter.setSeed(seed);
if (Common.checkFLName(flParameter.getFlName())) {
deprecatedSetBatchSize(batchSize);
} else {
@ -446,7 +456,7 @@ public class FLLiteClient {
return status;
}
private Map<String, float[]> getFeatureMap() {
public Map<String, float[]> getFeatureMap() {
Map<String, float[]> featureMap = new HashMap<>();
if (Common.checkFLName(flParameter.getFlName())) {
featureMap = deprecatedGetFeatureMap();
@ -530,8 +540,7 @@ public class FLLiteClient {
localFLParameter.getEncryptLevel().toString() + "> : " + curStatus));
return curStatus;
case DP_ENCRYPT:
// get the feature map before train
oldFeatureMap = getFeatureMap();
oldFeatureMap = localFLParameter.getOldFeatureMap();
curStatus = secureProtocol.setDPParameter(iteration, dpEps, dpDelta, dpNormClipAdapt, oldFeatureMap);
retCode = ResponseCode.SUCCEED;
if (curStatus != FLClientStatus.SUCCESS) {
@ -542,8 +551,7 @@ public class FLLiteClient {
LOGGER.info(Common.addTag("[Encrypt] set parameters for DP_ENCRYPT!"));
return FLClientStatus.SUCCESS;
case SIGNDS:
// get the feature map before train
oldFeatureMap = getFeatureMap();
oldFeatureMap = localFLParameter.getOldFeatureMap();
curStatus = secureProtocol.setDSParameter(signK, signEps, signThrRatio, signGlobalLr, signDimOut, oldFeatureMap);
retCode = ResponseCode.SUCCEED;
if (curStatus != FLClientStatus.SUCCESS) {

View File

@ -18,7 +18,9 @@ package com.mindspore.flclient;
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
import com.mindspore.flclient.compression.CompressMode;
import com.mindspore.flclient.model.RunType;
import mindspore.schema.CompressType;
import java.util.ArrayList;
import java.util.Arrays;
@ -603,6 +605,16 @@ public class FLParameter {
this.batchSize = batchSize;
}
/**
 * Collects every download compress type this client supports, in the iteration
 * order of {@link CompressMode#COMPRESS_TYPE_MAP}, for advertising to the server.
 *
 * @return array of supported compress type codes.
 */
public byte[] getDownloadCompressTypes() {
    byte[] supportedTypes = new byte[CompressMode.COMPRESS_TYPE_MAP.size()];
    int pos = 0;
    for (byte typeCode : CompressMode.COMPRESS_TYPE_MAP.keySet()) {
        supportedTypes[pos++] = typeCode;
    }
    return supportedTypes;
}
public int[][] getInputShape() {
return inputShape;
}

View File

@ -18,6 +18,7 @@ package com.mindspore.flclient;
import com.google.flatbuffers.FlatBufferBuilder;
import com.mindspore.flclient.compression.DecodeExecutor;
import com.mindspore.flclient.model.AlInferBert;
import com.mindspore.flclient.model.AlTrainBert;
import com.mindspore.flclient.model.Client;
@ -27,11 +28,9 @@ import com.mindspore.flclient.model.SessionUtil;
import com.mindspore.flclient.model.Status;
import com.mindspore.flclient.model.TrainLenet;
import mindspore.schema.FeatureMap;
import mindspore.schema.RequestGetModel;
import mindspore.schema.ResponseCode;
import mindspore.schema.ResponseGetModel;
import mindspore.schema.*;
import java.util.List;
import java.util.ArrayList;
import java.util.Date;
import java.util.logging.Logger;
@ -94,7 +93,8 @@ public class GetModel {
throw new IllegalArgumentException();
}
RequestGetModelBuilder builder = new RequestGetModelBuilder();
return builder.iteration(iteration).flName(name).time().build();
return builder.iteration(iteration).flName(name).time()
.downloadCompressTypesBuilder(flParameter.getDownloadCompressTypes()).build();
}
private FLClientStatus deprecatedParseResponseAlbert(ResponseGetModel responseDataBuf) {
@ -226,11 +226,29 @@ public class GetModel {
return status;
}
/**
 * Extracts the feature maps from a getModel response. Uncompressed responses are
 * read directly; otherwise the compressed feature maps are collected and handed to
 * {@link DecodeExecutor} for decompression.
 *
 * @param responseDataBuf the server's getModel response buffer.
 * @return the (possibly decompressed) feature map list.
 */
private List<FeatureMap> parseFeatureMapList(ResponseGetModel responseDataBuf) {
    byte compressType = responseDataBuf.downloadCompressType();
    if (responseDataBuf.downloadCompressType() == mindspore.schema.CompressType.NO_COMPRESS) {
        List<FeatureMap> plainMaps = new ArrayList<>();
        for (int idx = 0; idx < responseDataBuf.featureMapLength(); idx++) {
            plainMaps.add(responseDataBuf.featureMap(idx));
        }
        return plainMaps;
    }
    List<mindspore.schema.CompressFeatureMap> compressedMaps = new ArrayList<>();
    for (int idx = 0; idx < responseDataBuf.compressFeatureMapLength(); idx++) {
        compressedMaps.add(responseDataBuf.compressFeatureMap(idx));
    }
    return DecodeExecutor.getInstance().deCompressWeight(compressType, compressedMaps);
}
private FLClientStatus parseResponseFeatures(ResponseGetModel responseDataBuf) {
FLClientStatus status;
Client client = ClientManager.getClient(flParameter.getFlName());
int fmCount = responseDataBuf.featureMapLength();
if (fmCount <= 0) {
List<FeatureMap> featureMapList = parseFeatureMapList(responseDataBuf);
if (featureMapList.size() <= 0) {
LOGGER.severe(Common.addTag("[getModel] the feature size get from server is zero"));
retCode = ResponseCode.SystemError;
return FLClientStatus.FAILED;
@ -239,8 +257,8 @@ public class GetModel {
LOGGER.info(Common.addTag("[getModel] parseResponseFeatures by " + localFLParameter.getServerMod()));
ArrayList<FeatureMap> trainFeatureMaps = new ArrayList<FeatureMap>();
ArrayList<FeatureMap> inferFeatureMaps = new ArrayList<FeatureMap>();
for (int i = 0; i < fmCount; i++) {
FeatureMap feature = responseDataBuf.featureMap(i);
for (int i = 0; i < featureMapList.size(); i++) {
FeatureMap feature = featureMapList.get(i);
if (feature == null) {
LOGGER.severe(Common.addTag("[getModel] the feature returned from server is null"));
retCode = ResponseCode.SystemError;
@ -289,8 +307,8 @@ public class GetModel {
} else if (localFLParameter.getServerMod().equals(ServerMod.FEDERATED_LEARNING.toString())) {
LOGGER.info(Common.addTag("[getModel] parseResponseFeatures by " + localFLParameter.getServerMod()));
ArrayList<FeatureMap> featureMaps = new ArrayList<FeatureMap>();
for (int i = 0; i < fmCount; i++) {
FeatureMap feature = responseDataBuf.featureMap(i);
for (int i = 0; i < featureMapList.size(); i++) {
FeatureMap feature = featureMapList.get(i);
if (feature == null) {
LOGGER.severe(Common.addTag("[getModel] the feature returned from server is null"));
retCode = ResponseCode.SystemError;
@ -365,6 +383,7 @@ public class GetModel {
private int nameOffset = 0;
private int iteration = 0;
private int timeStampOffset = 0;
private int downloadCompressTypesOffset = 0;
public RequestGetModelBuilder() {
builder = new FlatBufferBuilder();
@ -392,11 +411,23 @@ public class GetModel {
return this;
}
/**
 * Serializes the list of download compress types this client supports into the
 * RequestGetModel flatbuffer.
 *
 * @param downloadCompressTypes the supported compress type codes; must be non-null and non-empty.
 * @return this builder, for chaining.
 * @throws IllegalArgumentException if downloadCompressTypes is null or empty.
 */
private RequestGetModelBuilder downloadCompressTypesBuilder(byte[] downloadCompressTypes) {
    if (downloadCompressTypes == null || downloadCompressTypes.length == 0) {
        LOGGER.severe(Common.addTag("[GetModel] the parameter of <downloadCompressTypes> is null or empty," +
                " please check!"));
        throw new IllegalArgumentException();
    }
    this.downloadCompressTypesOffset = RequestGetModel.createDownloadCompressTypesVector(builder,
            downloadCompressTypes);
    return this;
}
private byte[] build() {
RequestGetModel.startRequestGetModel(builder);
RequestGetModel.addFlName(builder, nameOffset);
RequestGetModel.addIteration(builder, iteration);
RequestGetModel.addTimestamp(builder, timeStampOffset);
RequestGetModel.addDownloadCompressTypes(builder, downloadCompressTypesOffset);
int root = RequestGetModel.endRequestGetModel(builder);
builder.finish(root);
return builder.sizedByteArray();

View File

@ -22,6 +22,7 @@ import org.bouncycastle.math.ec.rfc7748.X25519;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;
/**
@ -83,6 +84,10 @@ public class LocalFLParameter {
private MSConfig msConfig = new MSConfig();
private boolean useSSL = true;
private float lr = 0.1f;
private Map<String, float[]> oldFeatureMap;
private byte uploadCompressType = 0;
private int seed = 0;
private float uploadSparseRatio = 0.08f;
private LocalFLParameter() {
@ -250,4 +255,36 @@ public class LocalFLParameter {
public void setLr(float lr) {
this.lr = lr;
}
// Snapshot of the model weights taken before local training (set in SyncFLJob);
// presumably consumed by the upload compression to diff against trained weights — confirm with callers.
public Map<String, float[]> getOldFeatureMap() {
    return oldFeatureMap;
}

public void setOldFeatureMap(Map<String, float[]> oldFeatureMap) {
    this.oldFeatureMap = oldFeatureMap;
}

// Compress type code applied when uploading model weights; defaults to 0.
public byte getUploadCompressType() {
    return uploadCompressType;
}

public void setUploadCompressType(byte uploadCompressType) {
    this.uploadCompressType = uploadCompressType;
}

// Random seed driving the sparse-mask construction in EncodeExecutor.
public int getSeed() {
    return seed;
}

public void setSeed(int seed) {
    this.seed = seed;
}

// Fraction of parameters retained by upload sparsification (default 0.08).
public float getUploadSparseRatio() {
    return uploadSparseRatio;
}

public void setUploadSparseRatio(float uploadSparseRatio) {
    this.uploadSparseRatio = uploadSparseRatio;
}
}

View File

@ -208,35 +208,34 @@ public class SecureProtocol {
* @param trainDataSize trainDataSize tne size of train data set.
* @return the serialized model weights after adding masks.
*/
public int[] pwMaskModel(FlatBufferBuilder builder, int trainDataSize, Map<String, float[]> trainedMap) {
public Map<String, List<Float>> pwMaskModel(FlatBufferBuilder builder, int trainDataSize, Map<String,
float[]> trainedMap) {
Map<String, List<Float>> featureMaps = new HashMap<>();
if (featureMask == null || featureMask.length == 0) {
LOGGER.severe("[Encrypt] feature mask is null, please check");
return new int[0];
return new HashMap<>();
}
LOGGER.info(String.format("[Encrypt] feature mask size: %s", featureMask.length));
int featureSize = updateFeatureName.size();
int[] featuresMap = new int[featureSize];
int maskIndex = 0;
for (int i = 0; i < featureSize; i++) {
String key = updateFeatureName.get(i);
float[] data = trainedMap.get(key);
List<Float> featureMap = new ArrayList<>();
LOGGER.info(String.format("[Encrypt] feature name: %s feature size: %s", key, data.length));
for (int j = 0; j < data.length; j++) {
float rawData = data[j];
if (maskIndex >= featureMask.length) {
LOGGER.severe("[Encrypt] the maskIndex is out of range for array featureMask, please check");
return new int[0];
return new HashMap<>();
}
float maskData = rawData * trainDataSize + featureMask[maskIndex];
maskIndex += 1;
data[j] = maskData;
featureMap.add(maskData);
}
int featureName = builder.createString(key);
int weight = FeatureMap.createDataVector(builder, data);
int featureMap = FeatureMap.createFeatureMap(builder, featureName, weight);
featuresMap[i] = featureMap;
featureMaps.put(key, featureMap);
}
return featuresMap;
return featureMaps;
}
/**
@ -365,7 +364,9 @@ public class SecureProtocol {
* @param trainDataSize tne size of train data set.
* @return the serialized model weights after adding masks.
*/
public int[] dpMaskModel(FlatBufferBuilder builder, int trainDataSize, Map<String, float[]> trainedMap) {
public Map<String, List<Float>> dpMaskModel(FlatBufferBuilder builder, int trainDataSize,
Map<String, float[]> trainedMap) {
Map<String, List<Float>> featureMaps = new HashMap<>();
// get feature map
Map<String, float[]> mapBeforeTrain = modelMap;
int featureSize = updateFeatureName.size();
@ -383,7 +384,7 @@ public class SecureProtocol {
float rawData = data[j];
if (j >= dataBeforeTrain.length) {
LOGGER.severe("[Encrypt] the index j is out of range for array dataBeforeTrain, please check");
return new int[0];
return new HashMap<>();
}
float rawDataBeforeTrain = dataBeforeTrain[j];
float updateData = rawData - rawDataBeforeTrain;
@ -393,23 +394,23 @@ public class SecureProtocol {
updateL2Norm = Math.sqrt(updateL2Norm);
if (updateL2Norm == 0) {
LOGGER.severe(Common.addTag("[Encrypt] updateL2Norm is 0, please check"));
return new int[0];
return new HashMap<>();
}
double clipFactor = Math.min(1.0, dpNormClip / updateL2Norm);
// clip and add noise
int[] featuresMap = new int[featureSize];
for (int i = 0; i < featureSize; i++) {
String key = updateFeatureName.get(i);
if (!trainedMap.containsKey(key)) {
LOGGER.severe("[Encrypt] the key: " + key + " is not in map, please check!");
return new int[0];
return new HashMap<>();
}
float[] data = trainedMap.get(key);
float[] data2 = new float[data.length];
List<Float> featureMap = new ArrayList<>();
if (!mapBeforeTrain.containsKey(key)) {
LOGGER.severe("[Encrypt] the key: " + key + " is not in mapBeforeTrain, please check!");
return new int[0];
return new HashMap<>();
}
float[] dataBeforeTrain = mapBeforeTrain.get(key);
@ -419,7 +420,7 @@ public class SecureProtocol {
float rawData = data[j];
if (j >= dataBeforeTrain.length) {
LOGGER.severe("[Encrypt] the index j is out of range for array dataBeforeTrain, please check");
return new int[0];
return new HashMap<>();
}
float rawDataBeforeTrain = dataBeforeTrain[j];
float updateData = rawData - rawDataBeforeTrain;
@ -432,13 +433,11 @@ public class SecureProtocol {
updateData += gaussianNoise;
data2[j] = rawDataBeforeTrain + updateData;
data2[j] = data2[j] * trainDataSize;
featureMap.add(data2[j]);
}
int featureName = builder.createString(key);
int weight = FeatureMap.createDataVector(builder, data2);
int featureMap = FeatureMap.createFeatureMap(builder, featureName, weight);
featuresMap[i] = featureMap;
featureMaps.put(key, featureMap);
}
return featuresMap;
return featureMaps;
}
/**

View File

@ -18,6 +18,7 @@ package com.mindspore.flclient;
import com.google.flatbuffers.FlatBufferBuilder;
import com.mindspore.flclient.compression.DecodeExecutor;
import com.mindspore.flclient.model.AlInferBert;
import com.mindspore.flclient.model.AlTrainBert;
import com.mindspore.flclient.model.Client;
@ -29,6 +30,7 @@ import com.mindspore.flclient.model.TrainLenet;
import com.mindspore.flclient.pki.PkiBean;
import com.mindspore.flclient.pki.PkiUtil;
import mindspore.schema.*;
import mindspore.schema.FLPlan;
import mindspore.schema.FeatureMap;
import mindspore.schema.RequestFLJob;
@ -38,6 +40,7 @@ import mindspore.schema.ResponseFLJob;
import java.io.IOException;
import java.security.cert.Certificate;
import java.util.ArrayList;
import java.util.List;
import java.util.logging.Logger;
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
@ -119,6 +122,7 @@ public class StartFLJob {
.iteration(iteration)
.signData(pkiBean.getSignData())
.certificateChain(pkiBean.getCertificates())
.downloadCompressTypesBuilder(flParameter.getDownloadCompressTypes())
.build();
}
return builder.flName(flParameter.getFlName())
@ -126,6 +130,7 @@ public class StartFLJob {
.id(localFLParameter.getFlID())
.dataSize(dataSize)
.iteration(iteration)
.downloadCompressTypesBuilder(flParameter.getDownloadCompressTypes())
.build();
}
@ -151,8 +156,9 @@ public class StartFLJob {
ArrayList<FeatureMap> albertFeatureMaps = new ArrayList<FeatureMap>();
ArrayList<FeatureMap> inferFeatureMaps = new ArrayList<FeatureMap>();
featureSize = 0;
for (int i = 0; i < fmCount; i++) {
FeatureMap feature = flJob.featureMap(i);
List<FeatureMap> featureMapList = parseFeatureMapList(flJob);
for (int i = 0; i < featureMapList.size(); i++) {
FeatureMap feature = featureMapList.get(i);
if (feature == null) {
LOGGER.severe(Common.addTag("[startFLJob] the feature returned from server is null"));
return FLClientStatus.FAILED;
@ -233,12 +239,14 @@ public class StartFLJob {
private FLClientStatus deprecatedParseResponseLenet(ResponseFLJob flJob) {
FLClientStatus status;
int fmCount = flJob.featureMapLength();
ArrayList<FeatureMap> featureMaps = new ArrayList<FeatureMap>();
updateFeatureName.clear();
featureSize = 0;
for (int i = 0; i < fmCount; i++) {
FeatureMap feature = flJob.featureMap(i);
List<FeatureMap> featureMapList = parseFeatureMapList(flJob);
ArrayList<FeatureMap> featureMaps = new ArrayList<>();
for (int i = 0; i < featureMapList.size(); i++) {
FeatureMap feature = featureMapList.get(i);
if (feature == null) {
LOGGER.severe(Common.addTag("[startFLJob] the feature returned from server is null"));
return FLClientStatus.FAILED;
@ -267,6 +275,24 @@ public class StartFLJob {
return FLClientStatus.SUCCESS;
}
/**
 * Extracts the feature maps from a startFLJob response, transparently decompressing
 * them when the server sent compressed weights.
 *
 * @param flJob the flatbuffer response returned by the server.
 * @return the list of plain FeatureMap objects (decoded if the server compressed them).
 */
private List<FeatureMap> parseFeatureMapList(ResponseFLJob flJob) {
    List<FeatureMap> featureMaps;
    byte compressType = flJob.downloadCompressType();
    if (flJob.downloadCompressType() == mindspore.schema.CompressType.NO_COMPRESS) {
        LOGGER.info(Common.addTag("[parseFeatureMapList] create no compress feature map."));
        featureMaps = new ArrayList<>();
        for (int i = 0; i < flJob.featureMapLength(); i++) {
            featureMaps.add(flJob.featureMap(i));
        }
    } else {
        // Collect the compressed entries and decode them with the scheme selected by compressType.
        List<CompressFeatureMap> compressFeatureMapList = new ArrayList<>();
        for (int i = 0; i < flJob.compressFeatureMapLength(); i++) {
            compressFeatureMapList.add(flJob.compressFeatureMap(i));
        }
        featureMaps = DecodeExecutor.getInstance().deCompressWeight(compressType, compressFeatureMapList);
    }
    return featureMaps;
}
private FLClientStatus hybridFeatures(ResponseFLJob flJob) {
FLClientStatus status;
@ -275,8 +301,23 @@ public class StartFLJob {
ArrayList<FeatureMap> trainFeatureMaps = new ArrayList<FeatureMap>();
ArrayList<FeatureMap> inferFeatureMaps = new ArrayList<FeatureMap>();
featureSize = 0;
List<FeatureMap> featureMaps;
byte compressType = flJob.downloadCompressType();
if (compressType == CompressType.NO_COMPRESS) {
featureMaps = new ArrayList<>();
for (int i = 0; i < fmCount; i++) {
featureMaps.add(flJob.featureMap(i));
}
} else {
List<CompressFeatureMap> compressFeatureMapList = new ArrayList<>();
for (int i = 0; i < flJob.compressFeatureMapLength(); i++) {
compressFeatureMapList.add(flJob.compressFeatureMap(i));
}
featureMaps = DecodeExecutor.getInstance().deCompressWeight(compressType, compressFeatureMapList);
fmCount = featureMaps.size();
}
for (int i = 0; i < fmCount; i++) {
FeatureMap feature = flJob.featureMap(i);
FeatureMap feature = featureMaps.get(i);
if (feature == null) {
LOGGER.severe(Common.addTag("[startFLJob] the feature returned from server is null"));
retCode = ResponseCode.SystemError;
@ -335,8 +376,23 @@ public class StartFLJob {
int fmCount = flJob.featureMapLength();
ArrayList<FeatureMap> featureMaps = new ArrayList<FeatureMap>();
featureSize = 0;
byte compressType = flJob.downloadCompressType();
List<FeatureMap> parseFeatureMaps;
if (compressType == CompressType.NO_COMPRESS) {
parseFeatureMaps = new ArrayList<>();
for (int i = 0; i < fmCount; i++) {
parseFeatureMaps.add(flJob.featureMap(i));
}
} else {
List<CompressFeatureMap> compressFeatureMapList = new ArrayList<>();
for (int i = 0; i < flJob.compressFeatureMapLength(); i++) {
compressFeatureMapList.add(flJob.compressFeatureMap(i));
}
parseFeatureMaps = DecodeExecutor.getInstance().deCompressWeight(compressType, compressFeatureMapList);
fmCount = parseFeatureMaps.size();
}
for (int i = 0; i < fmCount; i++) {
FeatureMap feature = flJob.featureMap(i);
FeatureMap feature = parseFeatureMaps.get(i);
if (feature == null) {
LOGGER.severe(Common.addTag("[startFLJob] the feature returned from server is null"));
retCode = ResponseCode.SystemError;
@ -437,8 +493,8 @@ public class StartFLJob {
switch (responseRetCode) {
case (ResponseCode.SUCCEED):
if (flJob.featureMapLength() <= 0) {
LOGGER.severe(Common.addTag("[startFLJob] the feature size get from server is zero"));
if (flJob.downloadCompressType() == CompressType.NO_COMPRESS && flJob.featureMapLength() <= 0) {
LOGGER.warning(Common.addTag("[startFLJob] the feature size get from server is zero"));
retCode = ResponseCode.SystemError;
return FLClientStatus.FAILED;
}
@ -484,6 +540,7 @@ public class StartFLJob {
private int equipCertOffset = 0;
private int equipCACertOffset = 0;
private int rootCertOffset = 0;
private int downloadCompressTypesOffset = 0;
public RequestStartFLJobBuilder() {
builder = new FlatBufferBuilder();
@ -598,6 +655,17 @@ public class StartFLJob {
return this;
}
/**
 * Serializes the list of download compress types this client supports into the
 * RequestFLJob flatbuffer.
 *
 * @param downloadCompressTypes the supported compress type codes; must be non-null and non-empty.
 * @return this builder, for chaining.
 * @throws IllegalArgumentException if downloadCompressTypes is null or empty.
 */
private RequestStartFLJobBuilder downloadCompressTypesBuilder(byte[] downloadCompressTypes) {
    if (downloadCompressTypes == null || downloadCompressTypes.length == 0) {
        LOGGER.severe(Common.addTag("[StartFLJob] the parameter of <downloadCompressTypes> is null or empty," +
                " please check!"));
        throw new IllegalArgumentException();
    }
    this.downloadCompressTypesOffset = RequestFLJob.createDownloadCompressTypesVector(builder,
            downloadCompressTypes);
    return this;
}
/**
* build protobuffer
*
@ -615,6 +683,7 @@ public class StartFLJob {
RequestFLJob.addEquipCaCert(builder, equipCACertOffset);
RequestFLJob.addEquipCert(builder, equipCertOffset);
RequestFLJob.addKeyAttestation(builder, keyAttestationOffset);
RequestFLJob.addDownloadCompressTypes(builder, downloadCompressTypesOffset);
int root = RequestFLJob.endRequestFLJob(builder);
builder.finish(root);
return builder.sizedByteArray();

View File

@ -147,6 +147,10 @@ public class SyncFLJob {
LOGGER.info(Common.addTag("[startFLJob] startFLJob succeed, curIteration: " + flLiteClient.getIteration()));
updateTryTimePerIter(flLiteClient);
// Copy weights before training.
Map<String, float[]> oldFeatureMap = flLiteClient.getFeatureMap();
localFLParameter.setOldFeatureMap(oldFeatureMap);
// create mask
curStatus = flLiteClient.getFeatureMask();
if (curStatus == FLClientStatus.RESTART) {

View File

@ -26,11 +26,15 @@ import com.mindspore.flclient.model.SessionUtil;
import com.mindspore.flclient.model.Status;
import com.mindspore.flclient.model.TrainLenet;
import com.mindspore.lite.MSTensor;
import com.mindspore.flclient.compression.EncodeExecutor;
import com.mindspore.flclient.compression.CompressWeight;
import mindspore.schema.FeatureMap;
import mindspore.schema.CompressFeatureMap;
import mindspore.schema.RequestUpdateModel;
import mindspore.schema.ResponseCode;
import mindspore.schema.ResponseUpdateModel;
import static mindspore.schema.CompressType.NO_COMPRESS;
import java.util.ArrayList;
import java.util.Date;
@ -208,6 +212,7 @@ public class UpdateModel {
private RequestUpdateModel requestUM;
private FlatBufferBuilder builder;
private int fmOffset = 0;
private int compFmOffset = 0;
private int nameOffset = 0;
private int idOffset = 0;
private int timestampOffset = 0;
@ -215,8 +220,11 @@ public class UpdateModel {
private int sign = 0;
private int indexArrayOffset = 0;
private int iteration = 0;
private byte uploadCompressType = 0;
private float uploadSparseRate = 0.0f;
private EncryptLevel encryptLevel = EncryptLevel.NOT_ENCRYPT;
private float uploadLossOffset = 0.0f;
private int nameVecOffset = 0;
private RequestUpdateModelBuilder(EncryptLevel encryptLevel) {
builder = new FlatBufferBuilder();
@ -294,34 +302,33 @@ public class UpdateModel {
} else {
trainedMap = getFeatureMap();
}
Map<String, List<Float>> featureMaps = new HashMap<>();
long startTime;
long endTime;
switch (encryptLevel) {
case PW_ENCRYPT:
int[] fmOffsetsPW = secureProtocol.pwMaskModel(builder, trainDataSize, trainedMap);
if (fmOffsetsPW == null || fmOffsetsPW.length == 0) {
LOGGER.severe("[Encrypt] the return fmOffsetsPW from <secureProtocol.pwMaskModel> is " +
featureMaps = secureProtocol.pwMaskModel(builder, trainDataSize, trainedMap);
if (featureMaps == null || featureMaps.size() == 0) {
LOGGER.severe("[Encrypt] the return featureMaps from <secureProtocol.pwMaskModel> is " +
"null, please check");
throw new IllegalArgumentException();
}
this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsetsPW);
LOGGER.info(Common.addTag("[Encrypt] pairwise mask model ok!"));
return this;
break;
case DP_ENCRYPT:
startTime = System.currentTimeMillis();
int[] fmOffsetsDP = secureProtocol.dpMaskModel(builder, trainDataSize, trainedMap);
if (fmOffsetsDP == null || fmOffsetsDP.length == 0) {
LOGGER.severe("[Encrypt] the return fmOffsetsDP from <secureProtocol.dpMaskModel> is " +
featureMaps = secureProtocol.dpMaskModel(builder, trainDataSize, trainedMap);
if (featureMaps == null || featureMaps.size() == 0) {
LOGGER.severe("[Encrypt] the return featureMaps from <secureProtocol.dpMaskModel> is " +
"null, please check");
retCode = ResponseCode.RequestError;
status = FLClientStatus.FAILED;
throw new IllegalArgumentException();
}
this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsetsDP);
LOGGER.info(Common.addTag("[Encrypt] DP mask model ok!"));
endTime = System.currentTimeMillis();
LOGGER.info(Common.addTag("[Encrypt] dp time is: " + (endTime - startTime) + "ms"));
return this;
LOGGER.info(Common.addTag("dp time is " + (endTime - startTime) + "ms"));
break;
case SIGNDS:
startTime = System.currentTimeMillis();
// signds alg return indexArray, and package indexArray into flatbuffer.
@ -352,31 +359,104 @@ public class UpdateModel {
this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsetsSignds);
LOGGER.info(Common.addTag("[Encrypt] SignDS mask model ok!"));
endTime = System.currentTimeMillis();
LOGGER.info(Common.addTag("[Encrypt] signds time is: " + (endTime - startTime) + "ms"));
LOGGER.info(Common.addTag("signds time is " + (endTime - startTime) + "ms"));
return this;
case NOT_ENCRYPT:
default:
startTime = System.currentTimeMillis();
int featureSize = updateFeatureName.size();
int[] fmOffsets = new int[featureSize];
for (int i = 0; i < featureSize; i++) {
String key = updateFeatureName.get(i);
float[] data = trainedMap.get(key);
LOGGER.info(Common.addTag("[updateModel build featuresMap] feature name: " + key + " feature " +
"size: " + data.length));
for (int j = 0; j < data.length; j++) {
data[j] = data[j] * trainDataSize;
for (String name : updateFeatureName) {
float[] data = trainedMap.get(name);
List<Float> featureMap = new ArrayList<>();
for (float datum : data) {
featureMap.add(datum * (float) trainDataSize);
}
int featureName = builder.createString(key);
int weight = FeatureMap.createDataVector(builder, data);
int featureMap = FeatureMap.createFeatureMap(builder, featureName, weight);
fmOffsets[i] = featureMap;
featureMaps.put(name, featureMap);
}
this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsets);
endTime = System.currentTimeMillis();
LOGGER.info(Common.addTag("[Encrypt] not encrypt time is: " + (endTime - startTime) + "ms"));
return this;
LOGGER.info(Common.addTag("not encrypt time is " + (endTime - startTime) + "ms"));
break;
}
byte uploadCompressType = localFLParameter.getUploadCompressType();
if (uploadCompressType != NO_COMPRESS) {
startTime = System.currentTimeMillis();
this.compFmOffset = buildCompFmOffset(featureMaps, trainDataSize);
this.uploadCompressType = localFLParameter.getUploadCompressType();
this.uploadSparseRate = localFLParameter.getUploadSparseRatio();
this.nameVecOffset = buildNameVecOffset(updateFeatureName);
endTime = System.currentTimeMillis();
LOGGER.info(Common.addTag("compression time is " + (endTime - startTime) + "ms"));
return this;
}
this.fmOffset = buildFmOffset(featureMaps, updateFeatureName);
return this;
}
/**
 * Encodes the trained weights with the configured upload compression and serializes
 * each resulting CompressWeight into a CompressFeatureMap flatbuffer entry.
 *
 * @param featureMaps weight name -> weight values (already scaled by trainDataSize in the calling path).
 * @param trainDataSize the size of the local train data set, forwarded to the encoder.
 * @return the offset of the created CompressFeatureMap vector inside this.builder.
 * @throws IllegalArgumentException if the encoder produces no compressed weights
 *         (retCode/status are set to error values first).
 */
private int buildCompFmOffset(Map<String, List<Float>> featureMaps, int trainDataSize) {
    List<CompressWeight> compressWeights = EncodeExecutor.getInstance().encode(featureMaps, trainDataSize);
    if (compressWeights == null || compressWeights.size() == 0) {
        LOGGER.severe("[Compression] the return compressWeights from <encodeExecutor.encode> is " +
                "null, please check");
        retCode = ResponseCode.RequestError;
        status = FLClientStatus.FAILED;
        throw new IllegalArgumentException();
    }
    int compFeatureSize = compressWeights.size();
    int[] compFmOffsets = new int[compFeatureSize];
    int index = 0;
    for (CompressWeight compressWeight : compressWeights) {
        String weightFullname = compressWeight.getWeightFullname();
        List<Byte> compressData = compressWeight.getCompressData();
        float minVal = compressWeight.getMinValue();
        float maxVal = compressWeight.getMaxValue();
        // Unbox the byte list into a primitive array for the flatbuffer vector.
        byte[] data = new byte[compressData.size()];
        LOGGER.info(Common.addTag("[updateModel build compressWeight] feature name: "
                + weightFullname + ", feature size: " + data.length));
        for (int j = 0; j < data.length; j++) {
            data[j] = compressData.get(j);
        }
        int featureName = builder.createString(weightFullname);
        int weight = CompressFeatureMap.createCompressDataVector(builder, data);
        int featureMap = CompressFeatureMap.createCompressFeatureMap(builder, featureName, weight,
                minVal, maxVal);
        LOGGER.info(Common.addTag("[Compression]" +
                " featureName: " + weightFullname +
                ", min_val: " + minVal +
                ", max_val: " + maxVal));
        compFmOffsets[index] = featureMap;
        index += 1;
    }
    return RequestUpdateModel.createCompressFeatureMapVector(builder, compFmOffsets);
}
/**
 * Serializes the ordered list of feature names into the name_vec vector of the
 * RequestUpdateModel flatbuffer.
 *
 * @param updateFeatureName the feature names to upload, in order.
 * @return the offset of the created name vector inside this.builder.
 */
private int buildNameVecOffset(ArrayList<String> updateFeatureName) {
    int[] nameOffsets = new int[updateFeatureName.size()];
    for (int idx = 0; idx < nameOffsets.length; idx++) {
        nameOffsets[idx] = builder.createString(updateFeatureName.get(idx));
    }
    return RequestUpdateModel.createNameVecVector(builder, nameOffsets);
}
/**
 * Serializes the (uncompressed) trained weights into a FeatureMap vector of the
 * RequestUpdateModel flatbuffer.
 *
 * @param featureMaps weight name -> weight values to upload.
 * @param updateFeatureName the feature names to upload, in order.
 * @return the offset of the created FeatureMap vector inside this.builder.
 */
private int buildFmOffset(Map<String, List<Float>> featureMaps, ArrayList<String> updateFeatureName) {
    int[] offsets = new int[updateFeatureName.size()];
    for (int idx = 0; idx < offsets.length; idx++) {
        String name = updateFeatureName.get(idx);
        List<Float> values = featureMaps.get(name);
        // Unbox into a primitive array for the flatbuffer data vector.
        float[] weights = new float[values.size()];
        LOGGER.info(Common.addTag("[updateModel build featuresMap] feature name: " + name + " feature " +
                "size: " + weights.length));
        int pos = 0;
        for (Float value : values) {
            weights[pos] = value;
            pos++;
        }
        int nameOffset = builder.createString(name);
        int dataOffset = FeatureMap.createDataVector(builder, weights);
        offsets[idx] = FeatureMap.createFeatureMap(builder, nameOffset, dataOffset);
    }
    return RequestUpdateModel.createFeatureMapVector(builder, offsets);
}
/**
@ -417,6 +497,10 @@ public class UpdateModel {
RequestUpdateModel.addFlId(this.builder, idOffset);
RequestUpdateModel.addTimestamp(builder, this.timestampOffset);
RequestUpdateModel.addIteration(builder, this.iteration);
RequestUpdateModel.addCompressFeatureMap(builder, this.compFmOffset);
RequestUpdateModel.addUploadCompressType(builder, this.uploadCompressType);
RequestUpdateModel.addUploadSparseRate(builder, this.uploadSparseRate);
RequestUpdateModel.addNameVec(builder, this.nameVecOffset);
RequestUpdateModel.addFeatureMap(builder, this.fmOffset);
RequestUpdateModel.addSignature(builder, this.signDataOffset);
RequestUpdateModel.addUploadLoss(builder, this.uploadLossOffset);

View File

@ -0,0 +1,40 @@
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2022. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.flclient.compression;
import java.util.HashMap;
import java.util.Map;
import static mindspore.schema.CompressType.NO_COMPRESS;
import static mindspore.schema.CompressType.QUANT;
/**
* The compress mode.
*
* @since 2021-12-21
*/
public class CompressMode {
    /**
     * Lookup table from compress type code to the quantization bit width used for
     * that scheme (NO_COMPRESS maps to -1).
     */
    public static final Map<Byte, Integer> COMPRESS_TYPE_MAP = createTypeMap();

    // Builds the compress-type -> num-bits table once at class initialization.
    private static Map<Byte, Integer> createTypeMap() {
        Map<Byte, Integer> typeMap = new HashMap<>();
        typeMap.put(NO_COMPRESS, -1);
        typeMap.put(QUANT, 8);
        return typeMap;
    }
}

View File

@ -0,0 +1,83 @@
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2022. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.flclient.compression;
import java.util.List;
/**
* Compress Weight Bean
*
* @since 2021-12-21
*/
public class CompressWeight {
private String weightFullname;
private List<Byte> compressData;
private float minValue;
private float maxValue;
public CompressWeight() {
}
public CompressWeight(String weightFullname, List<Byte> compressData, float minValue, float maxValue) {
this.weightFullname = weightFullname;
this.compressData = compressData;
this.minValue = minValue;
this.maxValue = maxValue;
}
public String getWeightFullname() {
return weightFullname;
}
public void setWeightFullname(String weightFullname) {
this.weightFullname = weightFullname;
}
public List<Byte> getCompressData() {
return compressData;
}
public void setCompressData(List<Byte> compressData) {
this.compressData = compressData;
}
public float getMinValue() {
return minValue;
}
public void setMinValue(float minValue) {
this.minValue = minValue;
}
public float getMaxValue() {
return maxValue;
}
public void setMaxValue(float maxValue) {
this.maxValue = maxValue;
}
@Override
public String toString() {
return "CompressWeight{" +
"weightFullname='" + weightFullname + '\'' +
", compressData=" + compressData +
", minValue=" + minValue +
", maxValue=" + maxValue +
'}';
}
}

View File

@ -0,0 +1,115 @@
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2022. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.flclient.compression;
import com.google.flatbuffers.FlatBufferBuilder;
import com.mindspore.flclient.Common;
import com.mindspore.flclient.StartFLJob;
import mindspore.schema.CompressFeatureMap;
import mindspore.schema.FeatureMap;
import mindspore.schema.CompressType;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;
import static mindspore.schema.CompressType.QUANT;
/**
* Decode Executor
*
* @since 2021-12-21
*/
public class DecodeExecutor {
    private static final Logger LOGGER = Logger.getLogger(DecodeExecutor.class.toString());

    // Lazily-created singleton; volatile makes double-checked locking safe.
    private static volatile DecodeExecutor compressExecutor;

    private DecodeExecutor() {}

    /**
     * Returns the process-wide singleton, created lazily with double-checked locking.
     */
    public static DecodeExecutor getInstance() {
        if (compressExecutor == null) {
            synchronized (DecodeExecutor.class) {
                if (compressExecutor == null) {
                    compressExecutor = new DecodeExecutor();
                }
            }
        }
        return compressExecutor;
    }

    /**
     * Decompresses the weights downloaded from the server.
     *
     * @param compressType the compress type code sent by the server.
     * @param compressFeatureMapList the compressed feature maps to decode.
     * @return the decoded feature maps; empty if the compress type is unknown or unsupported.
     */
    public List<FeatureMap> deCompressWeight(byte compressType, List<CompressFeatureMap> compressFeatureMapList) {
        if (!CompressMode.COMPRESS_TYPE_MAP.containsKey(compressType)) {
            return new ArrayList<>();
        }
        LOGGER.info(Common.addTag("[deCompressWeight] create " + CompressType.name(compressType) + " feature map."));
        int numBits = CompressMode.COMPRESS_TYPE_MAP.get(compressType);
        if (compressType == QUANT) {
            return deCompressQuantMinMax(compressFeatureMapList, numBits);
        }
        return new ArrayList<>();
    }

    /**
     * Decodes min-max quantized weights back to floats.
     *
     * <p>Each quantized byte b is mapped to (b + 2^(numBits-1)) * scale + minVal, where
     * scale = (maxVal - minVal) / (2^numBits - 1); a small epsilon keeps scale non-zero
     * when minVal == maxVal.
     *
     * @param compressFeatureMapList the quantized feature maps.
     * @param numBits the quantization bit width.
     * @return the reconstructed feature maps serialized as flatbuffer FeatureMap objects.
     */
    private List<FeatureMap> deCompressQuantMinMax(List<CompressFeatureMap> compressFeatureMapList, int numBits) {
        float quantLevels = (float) (Math.pow(2, numBits) - 1);
        float zeroShift = (float) Math.pow(2, numBits - 1);
        Map<String, float[]> deCompressFeatureMaps = new HashMap<>();
        for (CompressFeatureMap compressFeatureMap : compressFeatureMapList) {
            String weightName = compressFeatureMap.weightFullname();
            int compressDataLength = compressFeatureMap.compressDataLength();
            float minVal = compressFeatureMap.minVal();
            float maxVal = compressFeatureMap.maxVal();
            // Epsilon avoids a zero scale when minVal == maxVal.
            float scaleValue = (float) ((maxVal - minVal) / quantLevels + 1e-10);
            float[] params = new float[compressDataLength];
            // Read each quantized byte directly from the flatbuffer; no intermediate copy needed.
            for (int j = 0; j < compressDataLength; j++) {
                params[j] = (compressFeatureMap.compressData(j) + zeroShift) * scaleValue + minVal;
            }
            deCompressFeatureMaps.put(weightName, params);
        }
        List<FeatureMap> featureMaps = new ArrayList<>();
        // Serialize each decoded weight array back into a standalone FeatureMap flatbuffer.
        for (Map.Entry<String, float[]> entry : deCompressFeatureMaps.entrySet()) {
            FlatBufferBuilder builder = new FlatBufferBuilder(0);
            int weightFullnameOffset = builder.createString(entry.getKey());
            int dataOffset = FeatureMap.createDataVector(builder, entry.getValue());
            FeatureMap.startFeatureMap(builder);
            FeatureMap.addWeightFullname(builder, weightFullnameOffset);
            FeatureMap.addData(builder, dataOffset);
            int orc = FeatureMap.endFeatureMap(builder);
            builder.finish(orc);
            FeatureMap featureMap = FeatureMap.getRootAsFeatureMap(builder.dataBuffer());
            featureMaps.add(featureMap);
        }
        return featureMaps;
    }
}

View File

@ -0,0 +1,167 @@
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2022. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.flclient.compression;
import com.mindspore.flclient.LocalFLParameter;
import static mindspore.schema.CompressType.DIFF_SPARSE_QUANT;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.PriorityQueue;
/**
 * Encode executor: compresses locally trained weights before upload.
 *
 * Implements the DIFF_SPARSE_QUANT pipeline: difference encoding against the
 * weights downloaded from the server, pseudo-random sparsification driven by a
 * seed shared with the server, then min/max quantization to signed bytes.
 *
 * @since 2021-12-21
 */
public class EncodeExecutor {
    private final LocalFLParameter localFLParameter = LocalFLParameter.getInstance();

    private static volatile EncodeExecutor encodeExecutor;

    // Parameters of the Lehmer-style pseudo random generator.  They must match
    // the server-side mask construction so that client and server derive the
    // same sparse mask from the shared seed; do not change them independently.
    private static final int multiplier = 2147483647;
    private static final double increment = 4294967294.0;
    private static final int modulo = 48271;

    private EncodeExecutor() {}

    /**
     * Returns the process-wide singleton (double-checked locking).
     */
    public static EncodeExecutor getInstance() {
        if (encodeExecutor == null) {
            synchronized (EncodeExecutor.class) {
                if (encodeExecutor == null) {
                    encodeExecutor = new EncodeExecutor();
                }
            }
        }
        return encodeExecutor;
    }

    /**
     * Builds a pseudo-randomly shuffled 0/1 mask of length {@code paramNum}
     * with {@code paramNum * uploadSparseRatio} ones; a 1 means the parameter
     * at that flat index is retained for upload.
     *
     * @param paramNum total number of parameters across all feature maps
     * @return shuffled mask of size {@code paramNum}
     */
    private List<Integer> constructMaskArray(int paramNum) {
        int seed = localFLParameter.getSeed();
        float uploadSparseRatio = localFLParameter.getUploadSparseRatio();

        int retainNum = (int) ((float) (paramNum) * uploadSparseRatio);
        List<Integer> maskArray = new ArrayList<>(paramNum);
        for (int i = 0; i < retainNum; ++i) {
            maskArray.add(1);
        }
        for (int i = retainNum; i < paramNum; ++i) {
            maskArray.add(0);
        }

        // Fisher-Yates-style shuffle driven by the shared seed.  The arithmetic
        // (including possible int overflow) intentionally mirrors the server
        // implementation bit-for-bit.
        seed = ((seed + multiplier) * modulo) % multiplier;
        for (int i = 0; i < paramNum; ++i) {
            // generate random number in (0, 1)
            double rand = (double) (seed) / increment + 0.5;
            // update seed
            seed = (seed * modulo) % multiplier;
            int j = (int) (rand * (double) (paramNum - i)) + i;
            int temp = maskArray.get(i);
            maskArray.set(i, maskArray.get(j));
            maskArray.set(j, temp);
        }
        return maskArray;
    }

    /**
     * Compresses the given feature maps with difference + sparse + quantization
     * encoding.
     *
     * @param featureMaps   weight full name -> trained weight values
     * @param numBits       quantization bit width (e.g. 8)
     * @param trainDataSize local train data size; the pre-training weights are
     *                      scaled by it before the difference is taken
     * @return one CompressWeight per feature map
     */
    public List<CompressWeight> enDiffSparseQuant(Map<String, List<Float>> featureMaps, int numBits,
                                                  int trainDataSize) {
        List<CompressWeight> compressWeights = new ArrayList<>();

        // 1. Difference encode: subtract the (scaled) weights downloaded from
        //    the server before training.
        Map<String, float[]> oldFeatureMap = localFLParameter.getOldFeatureMap();
        Map<String, List<Float>> diffFeatureMaps = new HashMap<>();
        for (Map.Entry<String, List<Float>> entry : featureMaps.entrySet()) {
            List<Float> featureMap = entry.getValue();
            float[] dataBeforeTrain = oldFeatureMap.get(entry.getKey());
            List<Float> diffs = new ArrayList<>(dataBeforeTrain.length);
            for (int i = 0; i < dataBeforeTrain.length; ++i) {
                diffs.add(featureMap.get(i) - dataBeforeTrain[i] * (float) trainDataSize);
            }
            diffFeatureMaps.put(entry.getKey(), diffs);
        }

        // 2. Sparse encode: keep only the entries selected by the shared mask.
        int paramNum = 0;
        for (List<Float> weight : diffFeatureMaps.values()) {
            paramNum += weight.size();
        }
        List<Integer> maskArray = constructMaskArray(paramNum);
        Map<String, List<Float>> sparseFeatureMaps = new HashMap<>();
        int index = 0;
        for (Map.Entry<String, List<Float>> entry : diffFeatureMaps.entrySet()) {
            List<Float> sparseFeatureMap = new ArrayList<>();
            for (Float dataValue : entry.getValue()) {
                if (maskArray.get(index) == 1) {
                    sparseFeatureMap.add(dataValue);
                }
                index += 1;
            }
            sparseFeatureMaps.put(entry.getKey(), sparseFeatureMap);
        }

        // 3. Quant encode: map each retained value onto numBits signed levels.
        float temp1 = (float) (1 << numBits) - 1.0f;   // level count minus one
        float temp2 = (float) (1 << (numBits - 1));    // signed offset
        for (Map.Entry<String, List<Float>> entry : sparseFeatureMaps.entrySet()) {
            CompressWeight compressWeight = new CompressWeight();
            compressWeight.setWeightFullname(entry.getKey());
            List<Float> sparseFeatureMap = entry.getValue();

            // Min and max of the retained values (primitive accumulators; the
            // original used a boxed Float, unboxing on every comparison).
            float minVal = Float.MAX_VALUE;
            float maxVal = -Float.MAX_VALUE;
            for (Float value : sparseFeatureMap) {
                if (value < minVal) {
                    minVal = value;
                }
                if (value > maxVal) {
                    maxVal = value;
                }
            }
            compressWeight.setMinValue(minVal);
            compressWeight.setMaxValue(maxVal);
            // 1e-10f keeps the scale non-zero when all values are equal.
            float scaleValue = (maxVal - minVal) / temp1 + 1e-10f;
            List<Byte> compressData = new ArrayList<>(sparseFeatureMap.size());
            for (Float value : sparseFeatureMap) {
                compressData.add((byte) (Math.round((value - minVal) / scaleValue - temp2)));
            }
            compressWeight.setCompressData(compressData);
            compressWeights.add(compressWeight);
        }
        return compressWeights;
    }

    /**
     * Dispatches to the configured upload compression scheme.
     *
     * @param featureMaps   weight full name -> trained weight values
     * @param trainDataSize local train data size
     * @return compressed weights ready for upload
     * @throws IllegalArgumentException if the configured compress type is unsupported
     */
    public List<CompressWeight> encode(Map<String, List<Float>> featureMaps, int trainDataSize) {
        byte uploadCompressType = localFLParameter.getUploadCompressType();
        if (uploadCompressType == DIFF_SPARSE_QUANT) {
            return enDiffSparseQuant(featureMaps, 8, trainDataSize);
        }
        throw new IllegalArgumentException("Unsupported upload compress type: " + uploadCompressType);
    }
}

View File

@ -15,8 +15,9 @@
"""Context for parameter server training mode"""
import os
from mindspore._checkparam import Validator
from mindspore._checkparam import Validator, Rel
from mindspore._c_expression import PSContext
from mindspore import log as logger
_ps_context = None
@ -79,6 +80,9 @@ _set_ps_context_func_map = {
"sign_global_lr": ps_context().set_sign_global_lr,
"sign_dim_out": ps_context().set_sign_dim_out,
"checkpoint_dir": ps_context().set_checkpoint_dir,
"upload_compress_type": ps_context().set_upload_compress_type,
"upload_sparse_rate": ps_context().set_upload_sparse_rate,
"download_compress_type": ps_context().set_download_compress_type,
}
_get_ps_context_func_map = {
@ -126,7 +130,10 @@ _get_ps_context_func_map = {
"sign_thr_ratio": ps_context().sign_thr_ratio,
"sign_global_lr": ps_context().sign_global_lr,
"sign_dim_out": ps_context().sign_dim_out,
"checkpoint_dir": ps_context().checkpoint_dir
"checkpoint_dir": ps_context().checkpoint_dir,
"upload_compress_type": ps_context().upload_compress_type,
"upload_sparse_rate": ps_context().upload_sparse_rate,
"download_compress_type": ps_context().download_compress_type,
}
_check_positive_int_keys = ["server_num", "scheduler_port", "fl_server_port",
@ -140,6 +147,15 @@ _check_positive_float_keys = ["update_model_ratio", "client_learning_rate"]
_check_port_keys = ["scheduler_port", "fl_server_port"]
_check_string_keys = {
"upload_compress_type": ["NO_COMPRESS", "DIFF_SPARSE_QUANT"],
"download_compress_type": ["NO_COMPRESS", "QUANT"],
}
_check_float_range_keys = {
"upload_sparse_rate": {"lower_limit": 0.0, "upper_limit": 1.0, "rel": Rel.INC_RIGHT},
}
def _get_ps_mode_rank():
ps_rank = ps_context().ps_rank_id()
if ps_rank == -1:
@ -183,6 +199,7 @@ def _set_ps_context(**kwargs):
Examples:
>>> context.set_ps_context(enable_ps=True, enable_ssl=True, client_password='123456', server_password='123456')
"""
kwargs = _check_conflict_value(kwargs)
for key, value in kwargs.items():
if key not in _set_ps_context_func_map:
raise ValueError("Set PS context keyword %s is not recognized!" % key)
@ -287,6 +304,31 @@ def _check_value(key, value):
if key in _check_positive_float_keys:
Validator.check_positive_float(value, key)
if key in _check_string_keys:
try:
string_keys = _check_string_keys[key]
Validator.check_string(value, string_keys)
except KeyError:
pass
if key in _check_float_range_keys:
try:
range_keys = _check_float_range_keys[key]
Validator.check_float_range(value, **range_keys)
except KeyError:
pass
if key in _check_port_keys:
if value < 1 or value > 65535:
raise ValueError("The range of %s must be 1 to 65535, but got %d." % (key, value))
def _check_conflict_value(kwargs):
if "upload_compress_type" in kwargs and " encrypt_type" in kwargs:
if kwargs["upload_compress_type"] != "NO_COMPRESS" and kwargs["encrypt_type"] in ("SIGNDS", "PW_ENCRYPT"):
logger.warning("The '{}' and '{}' are conflicted, and in '{}' mode the"
" 'upload_compress_type' will be 'NO_COMPRESS'".format(kwargs["encrypt_type"],
kwargs["upload_compress_type"],
kwargs["encrypt_type"]))
kwargs["upload_compress_type"] = "NO_COMPRESS"
return kwargs

View File

@ -47,6 +47,16 @@ table FeatureMap{
weight_fullname:string;
data:[float];
}
// Compression schemes exchanged between client and server.
// NO_COMPRESS disables compression; DIFF_SPARSE_QUANT is the upload pipeline
// (difference + sparse + quantization); QUANT is min/max quantization.
enum CompressType:byte {NO_COMPRESS = 0, DIFF_SPARSE_QUANT = 1, QUANT = 2}
// One quantized weight tensor: compress_data holds the signed int8 quantized
// values, min_val/max_val the float range used to de-quantize them.
table CompressFeatureMap{
// full name of the weight this tensor belongs to
weight_fullname:string;
// quantized values, one signed byte per retained parameter
compress_data:[int8];
// minimum of the original float values
min_val:float;
// maximum of the original float values
max_val:float;
}
table RequestFLJob{
fl_name:string;
fl_id:string;
@ -58,6 +68,7 @@ table RequestFLJob{
equip_cert:string;
equip_ca_cert:string;
root_cert:string;
download_compress_types:[CompressType];
}
table ResponseFLJob {
retcode:int;
@ -68,6 +79,10 @@ table ResponseFLJob {
fl_plan_config:FLPlan;
feature_map:[FeatureMap];
timestamp:string;
upload_compress_type:CompressType;
upload_sparse_rate:float;
download_compress_type:CompressType;
compress_feature_map:[CompressFeatureMap];
}
table FLPlan {
@ -94,6 +109,10 @@ table RequestUpdateModel{
upload_loss:float;
sign:int;
index_array:[int];
compress_feature_map:[CompressFeatureMap];
upload_compress_type:CompressType;
upload_sparse_rate:float;
name_vec:[string];
}
table ResponseUpdateModel{
@ -132,6 +151,7 @@ table RequestGetModel{
fl_name:string;
iteration:int;
timestamp:string;
download_compress_types:[CompressType];
}
table ResponseGetModel{
retcode:int;
@ -139,6 +159,8 @@ table ResponseGetModel{
iteration:int;
feature_map:[FeatureMap];
timestamp:string;
download_compress_type:CompressType;
compress_feature_map:[CompressFeatureMap];
}
table RequestAsyncGetModel{