forked from mindspore-Ecosystem/mindspore
fl compression
This commit is contained in:
parent
0373d2f915
commit
ce17db99a6
|
@ -51,6 +51,8 @@ if(NOT ENABLE_CPU OR WIN32)
|
|||
list(REMOVE_ITEM _FL_SRC_FILES "armour/cipher/cipher_reconstruct.cc")
|
||||
list(REMOVE_ITEM _FL_SRC_FILES "armour/cipher/cipher_shares.cc")
|
||||
list(REMOVE_ITEM _FL_SRC_FILES "armour/cipher/cipher_unmask.cc")
|
||||
list(REMOVE_ITEM _FL_SRC_FILES "compression/decode_executor.cc")
|
||||
list(REMOVE_ITEM _FL_SRC_FILES "compression/encode_executor.cc")
|
||||
endif()
|
||||
|
||||
if(CMAKE_SYSTEM_NAME MATCHES "Darwin")
|
||||
|
|
|
@ -0,0 +1,150 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "fl/compression/decode_executor.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace fl {
|
||||
namespace compression {
|
||||
std::vector<int> DecodeExecutor::ConstructMaskArray(int seed, float upload_sparse_rate, size_t param_num) {
|
||||
static int multiplier = 2147483647;
|
||||
static double increment = 4294967294.0;
|
||||
static int modulo = 48271;
|
||||
size_t retain_num = size_t(static_cast<float>(param_num) * upload_sparse_rate);
|
||||
if (retain_num == 0) {
|
||||
MS_LOG(WARNING) << "The retain_num is 0, and upload_sparse_rate is too small.";
|
||||
}
|
||||
std::vector<int> mask_array(param_num, 0);
|
||||
for (size_t i = 0; i < retain_num; ++i) {
|
||||
mask_array[i] = 1;
|
||||
}
|
||||
|
||||
seed = ((seed + multiplier) * modulo) % multiplier;
|
||||
for (size_t i = 0; i < param_num; ++i) {
|
||||
// generate random number in (0, 1)
|
||||
double rand = static_cast<double>(seed) / increment + 0.5;
|
||||
// update seed
|
||||
seed = (seed * modulo) % multiplier;
|
||||
size_t j = size_t(rand * static_cast<double>(param_num - i)) + i;
|
||||
int temp = mask_array[i];
|
||||
mask_array[i] = mask_array[j];
|
||||
mask_array[j] = temp;
|
||||
}
|
||||
return mask_array;
|
||||
}
|
||||
|
||||
bool DecodeExecutor::DeQuantSparseDiff(std::map<std::string, std::vector<float>> *weight_map,
|
||||
const std::vector<CompressFeatureMap> &compress_feature_maps, size_t num_bits,
|
||||
float upload_sparse_rate, int seed, const std::vector<std::string> &name_vec,
|
||||
size_t data_size) {
|
||||
std::vector<std::vector<float>> decompress_feature_maps;
|
||||
|
||||
// origin parameters
|
||||
std::vector<size_t> shape_vec;
|
||||
size_t param_num = 0;
|
||||
const auto &iter_to_model = mindspore::fl::server::ModelStore::GetInstance().iteration_to_model();
|
||||
size_t latest_iter_num = iter_to_model.rbegin()->first;
|
||||
std::map<std::string, AddressPtr> feature_maps =
|
||||
mindspore::fl::server::ModelStore::GetInstance().GetModelByIterNum(latest_iter_num);
|
||||
// get shape vector and number of upload parameters
|
||||
for (const auto &name : name_vec) {
|
||||
size_t shape = feature_maps[name]->size / sizeof(float);
|
||||
shape_vec.emplace_back(shape);
|
||||
param_num += shape;
|
||||
}
|
||||
MS_LOG(DEBUG) << "Compression get last weights success!";
|
||||
|
||||
// quant decode
|
||||
auto temp1 = static_cast<float>(1 << num_bits) - 1.0f;
|
||||
auto temp2 = static_cast<float>(1 << (num_bits - 1));
|
||||
std::vector<float> de_min_max_feature_map;
|
||||
for (auto compress_feature_map : compress_feature_maps) {
|
||||
float min_val = compress_feature_map.min_val;
|
||||
float max_val = compress_feature_map.max_val;
|
||||
float scale_val = static_cast<float>(max_val - min_val) / temp1 + 1e-10f;
|
||||
size_t size = compress_feature_map.compress_data.size();
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
de_min_max_feature_map.emplace_back(
|
||||
(static_cast<float>(compress_feature_map.compress_data[i]) + temp2) * scale_val + min_val);
|
||||
}
|
||||
}
|
||||
MS_LOG(DEBUG) << "Compression quant decode success!";
|
||||
|
||||
// sparse decode
|
||||
std::vector<int> mask_array = ConstructMaskArray(seed, upload_sparse_rate, param_num);
|
||||
size_t index = 0;
|
||||
size_t de_min_max_feature_map_index = 0;
|
||||
for (const auto &shape : shape_vec) {
|
||||
std::vector<float> feature_map(shape);
|
||||
for (size_t i = 0; i < shape; ++i) {
|
||||
if (index >= mask_array.size()) {
|
||||
MS_LOG(WARNING) << "The mask_array and parameter shape is not matched.";
|
||||
return false;
|
||||
}
|
||||
if (mask_array[index] == 1) {
|
||||
if (de_min_max_feature_map_index >= de_min_max_feature_map.size()) {
|
||||
MS_LOG(WARNING) << "The number of upload parameters is too small.";
|
||||
return false;
|
||||
}
|
||||
feature_map[i] = de_min_max_feature_map[de_min_max_feature_map_index];
|
||||
de_min_max_feature_map_index += 1;
|
||||
} else {
|
||||
feature_map[i] = 0.0f;
|
||||
}
|
||||
index += 1;
|
||||
}
|
||||
decompress_feature_maps.emplace_back(feature_map);
|
||||
}
|
||||
MS_LOG(DEBUG) << "Compression sparse decode success!";
|
||||
|
||||
// difference decode
|
||||
for (size_t i = 0; i < decompress_feature_maps.size(); ++i) {
|
||||
size_t feature_size = decompress_feature_maps[i].size();
|
||||
std::string name = name_vec[i];
|
||||
float *weight_data = reinterpret_cast<float *>(feature_maps[name]->addr);
|
||||
auto &weight_item = (*weight_map)[name];
|
||||
weight_item.resize(feature_size);
|
||||
for (size_t j = 0; j < feature_size; ++j) {
|
||||
weight_item[j] = decompress_feature_maps[i][j] + data_size * weight_data[j];
|
||||
}
|
||||
}
|
||||
MS_LOG(DEBUG) << "Compression difference decode success!";
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool DecodeExecutor::Decode(std::map<std::string, std::vector<float>> *weight_map,
|
||||
const std::vector<CompressFeatureMap> &compress_feature_maps,
|
||||
schema::CompressType upload_compress_type, float upload_sparse_rate, int seed,
|
||||
const std::vector<std::string> &name_vec, size_t data_size) {
|
||||
if (upload_compress_type == schema::CompressType_DIFF_SPARSE_QUANT) {
|
||||
return DeQuantSparseDiff(weight_map, compress_feature_maps, 8, upload_sparse_rate, seed, name_vec, data_size);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
schema::CompressType DecodeExecutor::GetCompressType(schema::CompressType upload_compress_type) {
|
||||
if (upload_compress_type == schema::CompressType_DIFF_SPARSE_QUANT) {
|
||||
MS_LOG(DEBUG) << "This upload compress type is DIFF_SPARSE_QUANT.";
|
||||
return schema::CompressType_DIFF_SPARSE_QUANT;
|
||||
}
|
||||
|
||||
MS_LOG(DEBUG) << "This upload compress type is NO_COMPRESS.";
|
||||
return schema::CompressType_NO_COMPRESS;
|
||||
}
|
||||
} // namespace compression
|
||||
} // namespace fl
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,74 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_FL_COMPRESSION_DECODE_EXECUTOR_H_
|
||||
#define MINDSPORE_CCSRC_FL_COMPRESSION_DECODE_EXECUTOR_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <functional>
|
||||
#include <algorithm>
|
||||
#include <regex>
|
||||
#include <map>
|
||||
#include <utility>
|
||||
#include "proto/comm.pb.h"
|
||||
#include "schema/fl_job_generated.h"
|
||||
#include "schema/cipher_generated.h"
|
||||
#include "fl/server/model_store.h"
|
||||
#include "fl/server/common.h"
|
||||
#include "ps/ps_context.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace fl {
|
||||
namespace compression {
|
||||
struct CompressFeatureMap {
|
||||
std::string weight_fullname;
|
||||
std::vector<int8_t> compress_data;
|
||||
float min_val;
|
||||
float max_val;
|
||||
};
|
||||
|
||||
class DecodeExecutor {
|
||||
public:
|
||||
static DecodeExecutor &GetInstance() {
|
||||
static DecodeExecutor instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
// construct mask array for random sparse
|
||||
std::vector<int> ConstructMaskArray(int seed, float upload_sparse_rate, size_t param_num);
|
||||
|
||||
// decode min_max quantization and random sparse and parameter difference
|
||||
bool DeQuantSparseDiff(std::map<std::string, std::vector<float>> *weight_map,
|
||||
const std::vector<CompressFeatureMap> &compress_feature_maps, size_t num_bits,
|
||||
float upload_sparse_rate, int seed, const std::vector<std::string> &name_vec,
|
||||
size_t data_size);
|
||||
|
||||
// decode
|
||||
bool Decode(std::map<std::string, std::vector<float>> *weight_map,
|
||||
const std::vector<CompressFeatureMap> &compress_feature_maps, schema::CompressType upload_compress_type,
|
||||
float upload_sparse_rate, int seed, const std::vector<std::string> &name_vec, size_t data_size);
|
||||
|
||||
schema::CompressType GetCompressType(schema::CompressType upload_compress_type);
|
||||
};
|
||||
} // namespace compression
|
||||
} // namespace fl
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_FL_COMPRESSION_DECODE_EXECUTOR_H_
|
|
@ -0,0 +1,102 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "fl/compression/encode_executor.h"
|
||||
|
||||
#include <arpa/inet.h>
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <functional>
|
||||
#include <algorithm>
|
||||
#include <regex>
|
||||
#include <map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include "fl/server/common.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace fl {
|
||||
namespace compression {
|
||||
bool CompressExecutor::EnableCompressWeight(const schema::CompressType compressType) {
|
||||
return kCompressTypeMap.count(compressType) > 0;
|
||||
}
|
||||
|
||||
bool CompressExecutor::construct_compress_weight(std::map<std::string, CompressWeight> *compressWeights,
|
||||
std::map<std::string, std::vector<float>> feature_maps,
|
||||
const schema::CompressType compressType) {
|
||||
if (compressType == schema::CompressType_QUANT) {
|
||||
return quant_min_max(compressWeights, feature_maps, kCompressTypeMap.at(compressType));
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool CompressExecutor::quant_min_max(std::map<std::string, CompressWeight> *compressWeights,
|
||||
std::map<std::string, std::vector<float>> feature_maps, size_t num_bits) {
|
||||
auto temp1 = static_cast<float>(1 << num_bits) - 1.0f;
|
||||
auto temp2 = static_cast<float>(1 << (num_bits - 1));
|
||||
for (const auto &feature_map : feature_maps) {
|
||||
std::string weight_name = feature_map.first;
|
||||
float min_value = 1e10f;
|
||||
float max_value = -min_value;
|
||||
for (const auto &feature : feature_map.second) {
|
||||
if (feature > max_value) {
|
||||
max_value = feature;
|
||||
}
|
||||
if (feature < min_value) {
|
||||
min_value = feature;
|
||||
}
|
||||
}
|
||||
float scale_value = (max_value - min_value) / temp1 + 1e-10f;
|
||||
size_t size = feature_map.second.size();
|
||||
if (size == 0) {
|
||||
MS_LOG(WARNING) << "The size of parameters is zero.";
|
||||
return false;
|
||||
}
|
||||
CompressWeight compressWeight;
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
auto round_data = round((feature_map.second[i] - min_value) / scale_value - temp2);
|
||||
// bit pack can be implemented here in the future
|
||||
auto int8_data = int8_t(round_data);
|
||||
compressWeight.compress_data.emplace_back(int8_data);
|
||||
}
|
||||
compressWeight.min_val = min_value;
|
||||
compressWeight.max_val = max_value;
|
||||
compressWeight.compress_data_len = size;
|
||||
|
||||
(*compressWeights)[weight_name] = compressWeight;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
schema::CompressType CompressExecutor::GetCompressType(const flatbuffers::Vector<int8_t> *download_compress_types) {
|
||||
schema::CompressType compressType = schema::CompressType_NO_COMPRESS;
|
||||
if (download_compress_types == nullptr) {
|
||||
MS_LOG(DEBUG) << "The client does not support current download compress type.";
|
||||
} else {
|
||||
for (size_t i = 0; i < download_compress_types->size(); ++i) {
|
||||
auto download_compress_type = download_compress_types->Get(i);
|
||||
if (download_compress_type == schema::CompressType_QUANT) {
|
||||
compressType = schema::CompressType_QUANT;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
return compressType;
|
||||
}
|
||||
} // namespace compression
|
||||
} // namespace fl
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,68 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_FL_COMPRESSION_ENCODE_EXECUTOR_H_
|
||||
#define MINDSPORE_CCSRC_FL_COMPRESSION_ENCODE_EXECUTOR_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include "proto/comm.pb.h"
|
||||
#include "schema/fl_job_generated.h"
|
||||
#include "schema/cipher_generated.h"
|
||||
#include "fl/armour/secure_protocol/key_agreement.h"
|
||||
#include "ps/ps_context.h"
|
||||
#include "ps/core/worker_node.h"
|
||||
#include "ps/core/cluster_metadata.h"
|
||||
#include "ps/core/communicator/tcp_communicator.h"
|
||||
#include "fl/server/common.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace fl {
|
||||
namespace compression {
|
||||
// compress type map: schema::CompressType -> num bits
|
||||
const std::map<schema::CompressType, size_t> kCompressTypeMap = {{schema::CompressType_QUANT, 8}};
|
||||
|
||||
struct CompressWeight {
|
||||
std::vector<int8_t> compress_data;
|
||||
size_t compress_data_len;
|
||||
float min_val;
|
||||
float max_val;
|
||||
};
|
||||
|
||||
class CompressExecutor {
|
||||
public:
|
||||
static CompressExecutor &GetInstance() {
|
||||
static CompressExecutor instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
bool EnableCompressWeight(const schema::CompressType compressType);
|
||||
|
||||
bool construct_compress_weight(std::map<std::string, CompressWeight> *compressWeights,
|
||||
std::map<std::string, std::vector<float>> feature_maps,
|
||||
const schema::CompressType compressType);
|
||||
|
||||
bool quant_min_max(std::map<std::string, CompressWeight> *compressWeights,
|
||||
std::map<std::string, std::vector<float>> feature_maps, size_t num_bits);
|
||||
|
||||
schema::CompressType GetCompressType(const flatbuffers::Vector<int8_t> *download_compress_types);
|
||||
};
|
||||
} // namespace compression
|
||||
} // namespace fl
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_FL_COMPRESSION_ENCODE_EXECUTOR_H_
|
|
@ -149,6 +149,11 @@ constexpr auto kUpdateModelRejectClientNum = "updateModelRejectClientNum";
|
|||
constexpr auto kGetModelTotalClientNum = "getModelTotalClientNum";
|
||||
constexpr auto kGetModelAcceptClientNum = "getModelAcceptClientNum";
|
||||
constexpr auto kGetModelRejectClientNum = "getModelRejectClientNum";
|
||||
constexpr auto kMinVal = "min_val";
|
||||
constexpr auto kMaxVal = "max_val";
|
||||
constexpr auto kQuant = "QUANT";
|
||||
constexpr auto kDiffSparseQuant = "DIFF_SPARSE_QUANT";
|
||||
constexpr auto kNoCompress = "NO_COMPRESS";
|
||||
|
||||
// OptimParamNameToIndex represents every inputs/workspace/outputs parameter's offset when an optimizer kernel is
|
||||
// launched.
|
||||
|
|
|
@ -588,6 +588,7 @@ void Iteration::Next(bool is_iteration_valid, const std::string &reason) {
|
|||
|
||||
if (LocalMetaStore::GetInstance().verifyAggregationFeatureMap(feature_map)) {
|
||||
ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model);
|
||||
ModelStore::GetInstance().StoreCompressModelByIterNum(iteration_num_, model);
|
||||
iteration_result_ = IterationResult::kSuccess;
|
||||
MS_LOG(INFO) << "Iteration " << iteration_num_ << " is successfully finished.";
|
||||
} else {
|
||||
|
@ -599,6 +600,7 @@ void Iteration::Next(bool is_iteration_valid, const std::string &reason) {
|
|||
size_t latest_iter_num = iter_to_model.rbegin()->first;
|
||||
const auto &model = ModelStore::GetInstance().GetModelByIterNum(latest_iter_num);
|
||||
ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model);
|
||||
ModelStore::GetInstance().StoreCompressModelByIterNum(iteration_num_, model);
|
||||
iteration_result_ = IterationResult::kFail;
|
||||
MS_LOG(WARNING) << "Iteration " << iteration_num_ << " is invalid. Reason: " << reason;
|
||||
}
|
||||
|
|
|
@ -92,7 +92,7 @@ void GetModelKernel::GetModel(const schema::RequestGetModel *get_model_req,
|
|||
return;
|
||||
}
|
||||
auto next_req_time = LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp);
|
||||
std::map<std::string, AddressPtr> feature_maps;
|
||||
std::map<std::string, AddressPtr> feature_maps = {};
|
||||
size_t current_iter = LocalMetaStore::GetInstance().curr_iter_num();
|
||||
size_t get_model_iter = IntToSize(get_model_req->iteration());
|
||||
const auto &iter_to_model = ModelStore::GetInstance().iteration_to_model();
|
||||
|
@ -110,6 +110,7 @@ void GetModelKernel::GetModel(const schema::RequestGetModel *get_model_req,
|
|||
SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return;
|
||||
}
|
||||
|
||||
IncreaseAcceptClientNum();
|
||||
auto real_get_model_iter = get_model_iter;
|
||||
if (iter_to_model.count(get_model_iter) == 0) {
|
||||
|
@ -118,12 +119,37 @@ void GetModelKernel::GetModel(const schema::RequestGetModel *get_model_req,
|
|||
<< " is invalid. Current iteration is " << std::to_string(current_iter);
|
||||
real_get_model_iter = latest_iter_num;
|
||||
}
|
||||
auto cache = ModelStore::GetInstance().GetModelResponseCache(name_, current_iter, real_get_model_iter);
|
||||
auto download_compress_types = get_model_req->download_compress_types();
|
||||
schema::CompressType compressType =
|
||||
mindspore::fl::compression::CompressExecutor::GetInstance().GetCompressType(download_compress_types);
|
||||
std::string compress_type;
|
||||
if (compressType == schema::CompressType_QUANT) {
|
||||
compress_type = kQuant;
|
||||
} else {
|
||||
compress_type = kNoCompress;
|
||||
}
|
||||
auto cache = ModelStore::GetInstance().GetModelResponseCache(name_, current_iter, real_get_model_iter, compress_type);
|
||||
if (cache == nullptr) {
|
||||
feature_maps = ModelStore::GetInstance().GetModelByIterNum(real_get_model_iter);
|
||||
// Only download compress weights if client support.
|
||||
std::map<std::string, AddressPtr> compress_feature_maps = {};
|
||||
if (compressType == schema::CompressType_NO_COMPRESS) {
|
||||
feature_maps = ModelStore::GetInstance().GetModelByIterNum(real_get_model_iter);
|
||||
} else {
|
||||
auto compressExecutor = mindspore::fl::compression::CompressExecutor::GetInstance();
|
||||
if (compressExecutor.EnableCompressWeight(compressType)) {
|
||||
const auto &iter_to_compress_model = ModelStore::GetInstance().iteration_to_compress_model();
|
||||
if (iter_to_compress_model.count(get_model_iter) == 0) {
|
||||
MS_LOG(DEBUG) << "The iteration of GetCompressModel request " << std::to_string(get_model_iter)
|
||||
<< " is invalid. Current iteration is " << std::to_string(current_iter);
|
||||
compress_feature_maps = ModelStore::GetInstance().GetCompressModelByIterNum(latest_iter_num, compressType);
|
||||
} else {
|
||||
compress_feature_maps = ModelStore::GetInstance().GetCompressModelByIterNum(get_model_iter, compressType);
|
||||
}
|
||||
}
|
||||
}
|
||||
BuildGetModelRsp(fbb, schema::ResponseCode_SUCCEED, "Get model for iteration " + std::to_string(get_model_iter),
|
||||
current_iter, feature_maps, std::to_string(next_req_time));
|
||||
cache = ModelStore::GetInstance().StoreModelResponseCache(name_, current_iter, real_get_model_iter,
|
||||
current_iter, feature_maps, std::to_string(next_req_time), compressType, compress_feature_maps);
|
||||
cache = ModelStore::GetInstance().StoreModelResponseCache(name_, current_iter, real_get_model_iter, compress_type,
|
||||
fbb->GetBufferPointer(), fbb->GetSize());
|
||||
if (cache == nullptr) {
|
||||
SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
|
@ -131,7 +157,7 @@ void GetModelKernel::GetModel(const schema::RequestGetModel *get_model_req,
|
|||
}
|
||||
}
|
||||
SendResponseMsgInference(message, cache->data(), cache->size(), ModelStore::GetInstance().RelModelResponseCache);
|
||||
MS_LOG(DEBUG) << "GetModel last iteratin is valid or not: " << Iteration::GetInstance().is_last_iteration_valid()
|
||||
MS_LOG(DEBUG) << "GetModel last iteration is valid or not: " << Iteration::GetInstance().is_last_iteration_valid()
|
||||
<< ", next request time is " << next_req_time << ", current iteration is " << current_iter;
|
||||
return;
|
||||
}
|
||||
|
@ -139,7 +165,8 @@ void GetModelKernel::GetModel(const schema::RequestGetModel *get_model_req,
|
|||
void GetModelKernel::BuildGetModelRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode,
|
||||
const std::string &reason, const size_t iter,
|
||||
const std::map<std::string, AddressPtr> &feature_maps,
|
||||
const std::string ×tamp) {
|
||||
const std::string ×tamp, const schema::CompressType &compressType,
|
||||
const std::map<std::string, AddressPtr> &compress_feature_maps) {
|
||||
if (fbb == nullptr) {
|
||||
MS_LOG(ERROR) << "Input fbb is nullptr.";
|
||||
return;
|
||||
|
@ -156,12 +183,40 @@ void GetModelKernel::BuildGetModelRsp(const std::shared_ptr<FBBuilder> &fbb, con
|
|||
}
|
||||
auto fbs_feature_maps_vector = fbb->CreateVector(fbs_feature_maps);
|
||||
|
||||
// construct compress feature maps with fbs
|
||||
std::vector<flatbuffers::Offset<schema::CompressFeatureMap>> fbs_compress_feature_maps;
|
||||
for (const auto &compress_feature_map : compress_feature_maps) {
|
||||
if (compress_feature_map.first.find(kMinVal) != string::npos ||
|
||||
compress_feature_map.first.find(kMaxVal) != string::npos) {
|
||||
continue;
|
||||
}
|
||||
auto fbs_compress_weight_fullname = fbb->CreateString(compress_feature_map.first);
|
||||
auto fbs_compress_weight_data = fbb->CreateVector(reinterpret_cast<int8_t *>(compress_feature_map.second->addr),
|
||||
compress_feature_map.second->size / sizeof(int8_t));
|
||||
|
||||
const std::string min_val_name = compress_feature_map.first + "." + kMinVal;
|
||||
const std::string max_val_name = compress_feature_map.first + "." + kMaxVal;
|
||||
|
||||
const AddressPtr min_val_ptr = compress_feature_maps.at(min_val_name);
|
||||
const AddressPtr max_val_ptr = compress_feature_maps.at(max_val_name);
|
||||
|
||||
float *fbs_min_val_ptr = reinterpret_cast<float *>(min_val_ptr->addr);
|
||||
float *fbs_max_val_ptr = reinterpret_cast<float *>(max_val_ptr->addr);
|
||||
auto fbs_compress_feature_map = schema::CreateCompressFeatureMap(
|
||||
*(fbb.get()), fbs_compress_weight_fullname, fbs_compress_weight_data, *fbs_min_val_ptr, *fbs_max_val_ptr);
|
||||
|
||||
fbs_compress_feature_maps.push_back(fbs_compress_feature_map);
|
||||
}
|
||||
auto fbs_compress_feature_maps_vector = fbb->CreateVector(fbs_compress_feature_maps);
|
||||
|
||||
schema::ResponseGetModelBuilder rsp_get_model_builder(*(fbb.get()));
|
||||
rsp_get_model_builder.add_retcode(static_cast<int>(retcode));
|
||||
rsp_get_model_builder.add_reason(fbs_reason);
|
||||
rsp_get_model_builder.add_iteration(static_cast<int>(iter));
|
||||
rsp_get_model_builder.add_feature_map(fbs_feature_maps_vector);
|
||||
rsp_get_model_builder.add_timestamp(fbs_timestamp);
|
||||
rsp_get_model_builder.add_download_compress_type(compressType);
|
||||
rsp_get_model_builder.add_compress_feature_map(fbs_compress_feature_maps_vector);
|
||||
auto rsp_get_model = rsp_get_model_builder.Finish();
|
||||
fbb->Finish(rsp_get_model);
|
||||
return;
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
#include "fl/server/executor.h"
|
||||
#include "fl/server/kernel/round/round_kernel.h"
|
||||
#include "fl/server/kernel/round/round_kernel_factory.h"
|
||||
#include "fl/compression/encode_executor.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace fl {
|
||||
|
@ -44,7 +45,9 @@ class GetModelKernel : public RoundKernel {
|
|||
void GetModel(const schema::RequestGetModel *get_model_req, const std::shared_ptr<ps::core::MessageHandler> &message);
|
||||
void BuildGetModelRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode,
|
||||
const std::string &reason, const size_t iter,
|
||||
const std::map<std::string, AddressPtr> &feature_maps, const std::string ×tamp);
|
||||
const std::map<std::string, AddressPtr> &feature_maps, const std::string ×tamp,
|
||||
const schema::CompressType &compressType = schema::CompressType_NO_COMPRESS,
|
||||
const std::map<std::string, AddressPtr> &compress_feature_maps = {});
|
||||
|
||||
// The executor is for getting model for getModel request.
|
||||
Executor *executor_;
|
||||
|
|
|
@ -126,10 +126,19 @@ bool StartFLJobKernel::Launch(const uint8_t *req_data, size_t len,
|
|||
IncreaseAcceptClientNum();
|
||||
auto curr_iter_num = LocalMetaStore::GetInstance().curr_iter_num();
|
||||
auto last_iteration = curr_iter_num - 1;
|
||||
auto cache = ModelStore::GetInstance().GetModelResponseCache(name_, curr_iter_num, last_iteration);
|
||||
auto download_compress_types = start_fl_job_req->download_compress_types();
|
||||
schema::CompressType compressType =
|
||||
mindspore::fl::compression::CompressExecutor::GetInstance().GetCompressType(download_compress_types);
|
||||
std::string compress_type;
|
||||
if (compressType == schema::CompressType_QUANT) {
|
||||
compress_type = kQuant;
|
||||
} else {
|
||||
compress_type = kNoCompress;
|
||||
}
|
||||
auto cache = ModelStore::GetInstance().GetModelResponseCache(name_, curr_iter_num, last_iteration, compress_type);
|
||||
if (cache == nullptr) {
|
||||
StartFLJob(fbb);
|
||||
cache = ModelStore::GetInstance().StoreModelResponseCache(name_, curr_iter_num, last_iteration,
|
||||
StartFLJob(fbb, device_meta, start_fl_job_req);
|
||||
cache = ModelStore::GetInstance().StoreModelResponseCache(name_, curr_iter_num, last_iteration, compress_type,
|
||||
fbb->GetBufferPointer(), fbb->GetSize());
|
||||
if (cache == nullptr) {
|
||||
SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
|
@ -303,22 +312,40 @@ ResultCode StartFLJobKernel::CountForStartFLJob(const std::shared_ptr<FBBuilder>
|
|||
return ResultCode::kSuccess;
|
||||
}
|
||||
|
||||
void StartFLJobKernel::StartFLJob(const std::shared_ptr<FBBuilder> &fbb) {
|
||||
void StartFLJobKernel::StartFLJob(const std::shared_ptr<FBBuilder> &fbb, const DeviceMeta &,
|
||||
const schema::RequestFLJob *start_fl_job_req) {
|
||||
size_t last_iteration = LocalMetaStore::GetInstance().curr_iter_num() - 1;
|
||||
auto feature_maps = ModelStore::GetInstance().GetModelByIterNum(last_iteration);
|
||||
if (feature_maps.empty()) {
|
||||
MS_LOG(WARNING) << "The feature map for startFLJob is empty.";
|
||||
|
||||
std::map<std::string, AddressPtr> feature_maps = {};
|
||||
std::map<std::string, AddressPtr> compress_feature_maps = {};
|
||||
|
||||
// Only download compress weights if client support.
|
||||
auto download_compress_types = start_fl_job_req->download_compress_types();
|
||||
schema::CompressType compressType =
|
||||
mindspore::fl::compression::CompressExecutor::GetInstance().GetCompressType(download_compress_types);
|
||||
if (compressType == schema::CompressType_NO_COMPRESS) {
|
||||
feature_maps = ModelStore::GetInstance().GetModelByIterNum(last_iteration);
|
||||
if (feature_maps.empty()) {
|
||||
MS_LOG(WARNING) << "The feature map for startFLJob is empty.";
|
||||
}
|
||||
} else {
|
||||
if (mindspore::fl::compression::CompressExecutor::GetInstance().EnableCompressWeight(compressType)) {
|
||||
compress_feature_maps = ModelStore::GetInstance().GetCompressModelByIterNum(last_iteration, compressType);
|
||||
}
|
||||
}
|
||||
|
||||
BuildStartFLJobRsp(fbb, schema::ResponseCode_SUCCEED, "success", true,
|
||||
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)),
|
||||
feature_maps);
|
||||
feature_maps, compressType, compress_feature_maps);
|
||||
return;
|
||||
}
|
||||
|
||||
void StartFLJobKernel::BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode,
|
||||
const std::string &reason, const bool is_selected,
|
||||
const std::string &next_req_time,
|
||||
std::map<std::string, AddressPtr> feature_maps) {
|
||||
const std::map<std::string, AddressPtr> &feature_maps,
|
||||
const schema::CompressType &compressType,
|
||||
const std::map<std::string, AddressPtr> &compress_feature_maps) {
|
||||
if (fbb == nullptr) {
|
||||
MS_LOG(WARNING) << "Input fbb is nullptr.";
|
||||
return;
|
||||
|
@ -350,6 +377,12 @@ void StartFLJobKernel::BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb,
|
|||
auto cipher_public_params =
|
||||
schema::CreateCipherPublicParams(*fbb.get(), encrypt_type, pw_params, dp_params, ds_params);
|
||||
#endif
|
||||
schema::CompressType upload_compress_type;
|
||||
if (ps::PSContext::instance()->upload_compress_type() == kDiffSparseQuant) {
|
||||
upload_compress_type = schema::CompressType_DIFF_SPARSE_QUANT;
|
||||
} else {
|
||||
upload_compress_type = schema::CompressType_NO_COMPRESS;
|
||||
}
|
||||
|
||||
schema::FLPlanBuilder fl_plan_builder(*(fbb.get()));
|
||||
fl_plan_builder.add_fl_name(fbs_fl_name);
|
||||
|
@ -375,6 +408,33 @@ void StartFLJobKernel::BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb,
|
|||
}
|
||||
auto fbs_feature_maps_vector = fbb->CreateVector(fbs_feature_maps);
|
||||
|
||||
// construct compress feature maps with fbs
|
||||
std::vector<flatbuffers::Offset<schema::CompressFeatureMap>> fbs_compress_feature_maps;
|
||||
for (const auto &compress_feature_map : compress_feature_maps) {
|
||||
if (compressType == schema::CompressType_QUANT) {
|
||||
if (compress_feature_map.first.find(kMinVal) != string::npos ||
|
||||
compress_feature_map.first.find(kMaxVal) != string::npos) {
|
||||
continue;
|
||||
}
|
||||
auto fbs_compress_weight_fullname = fbb->CreateString(compress_feature_map.first);
|
||||
auto fbs_compress_weight_data = fbb->CreateVector(reinterpret_cast<int8_t *>(compress_feature_map.second->addr),
|
||||
compress_feature_map.second->size / sizeof(int8_t));
|
||||
|
||||
const std::string min_val_name = compress_feature_map.first + "." + kMinVal;
|
||||
const std::string max_val_name = compress_feature_map.first + "." + kMaxVal;
|
||||
|
||||
const AddressPtr min_val_ptr = compress_feature_maps.at(min_val_name);
|
||||
const AddressPtr max_val_ptr = compress_feature_maps.at(max_val_name);
|
||||
|
||||
float *fbs_min_val_ptr = reinterpret_cast<float *>(min_val_ptr->addr);
|
||||
float *fbs_max_val_ptr = reinterpret_cast<float *>(max_val_ptr->addr);
|
||||
auto fbs_compress_feature_map = schema::CreateCompressFeatureMap(
|
||||
*(fbb.get()), fbs_compress_weight_fullname, fbs_compress_weight_data, *fbs_min_val_ptr, *fbs_max_val_ptr);
|
||||
fbs_compress_feature_maps.push_back(fbs_compress_feature_map);
|
||||
}
|
||||
}
|
||||
auto fbs_compress_feature_maps_vector = fbb->CreateVector(fbs_compress_feature_maps);
|
||||
|
||||
schema::ResponseFLJobBuilder rsp_fl_job_builder(*(fbb.get()));
|
||||
rsp_fl_job_builder.add_retcode(static_cast<int>(retcode));
|
||||
rsp_fl_job_builder.add_reason(fbs_reason);
|
||||
|
@ -383,6 +443,10 @@ void StartFLJobKernel::BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb,
|
|||
rsp_fl_job_builder.add_next_req_time(fbs_next_req_time);
|
||||
rsp_fl_job_builder.add_fl_plan_config(fbs_fl_plan);
|
||||
rsp_fl_job_builder.add_feature_map(fbs_feature_maps_vector);
|
||||
rsp_fl_job_builder.add_download_compress_type(compressType);
|
||||
rsp_fl_job_builder.add_compress_feature_map(fbs_compress_feature_maps_vector);
|
||||
rsp_fl_job_builder.add_upload_compress_type(upload_compress_type);
|
||||
rsp_fl_job_builder.add_upload_sparse_rate(ps::PSContext::instance()->upload_sparse_rate());
|
||||
auto rsp_fl_job = rsp_fl_job_builder.Finish();
|
||||
fbb->Finish(rsp_fl_job);
|
||||
return;
|
||||
|
|
|
@ -25,6 +25,9 @@
|
|||
#include "fl/server/executor.h"
|
||||
#include "fl/server/kernel/round/round_kernel.h"
|
||||
#include "fl/server/kernel/round/round_kernel_factory.h"
|
||||
#include "schema/fl_job_generated.h"
|
||||
#include "schema/cipher_generated.h"
|
||||
#include "fl/compression/encode_executor.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace fl {
|
||||
|
@ -56,7 +59,8 @@ class StartFLJobKernel : public RoundKernel {
|
|||
// Distributed count service counts for startFLJob.
|
||||
ResultCode CountForStartFLJob(const std::shared_ptr<FBBuilder> &fbb, const schema::RequestFLJob *start_fl_job_req);
|
||||
|
||||
void StartFLJob(const std::shared_ptr<FBBuilder> &fbb);
|
||||
void StartFLJob(const std::shared_ptr<FBBuilder> &fbb, const DeviceMeta &device_meta,
|
||||
const schema::RequestFLJob *start_fl_job_req);
|
||||
|
||||
bool JudgeFLJobCert(const std::shared_ptr<FBBuilder> &fbb, const schema::RequestFLJob *start_fl_job_req);
|
||||
|
||||
|
@ -65,7 +69,9 @@ class StartFLJobKernel : public RoundKernel {
|
|||
// Build response for startFLJob round no matter success or failure.
|
||||
void BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode,
|
||||
const std::string &reason, const bool is_selected, const std::string &next_req_time,
|
||||
std::map<std::string, AddressPtr> feature_maps = {});
|
||||
const std::map<std::string, AddressPtr> &feature_maps = {},
|
||||
const schema::CompressType &compressType = schema::CompressType_NO_COMPRESS,
|
||||
const std::map<std::string, AddressPtr> &compress_feature_maps = {});
|
||||
|
||||
// The executor is for getting the initial model for startFLJob request.
|
||||
Executor *executor_;
|
||||
|
|
|
@ -201,23 +201,27 @@ ResultCode UpdateModelKernel::VerifyUpdateModel(const schema::RequestUpdateModel
|
|||
}
|
||||
|
||||
std::unordered_map<std::string, size_t> feature_map;
|
||||
auto upload_feature_map = update_model_req->feature_map();
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(upload_feature_map, ResultCode::kFail);
|
||||
for (uint32_t i = 0; i < upload_feature_map->size(); i++) {
|
||||
const auto &item = upload_feature_map->Get(i);
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(item, ResultCode::kFail);
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(item->weight_fullname(), ResultCode::kFail);
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(item->data(), ResultCode::kFail);
|
||||
if (ps::PSContext::instance()->upload_compress_type() != kDiffSparseQuant) {
|
||||
auto upload_feature_map = update_model_req->feature_map();
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(upload_feature_map, ResultCode::kFail);
|
||||
for (uint32_t i = 0; i < upload_feature_map->size(); i++) {
|
||||
const auto &item = upload_feature_map->Get(i);
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(item, ResultCode::kFail);
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(item->weight_fullname(), ResultCode::kFail);
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(item->data(), ResultCode::kFail);
|
||||
|
||||
std::string weight_full_name = item->weight_fullname()->str();
|
||||
size_t weight_size = item->data()->size() * sizeof(float);
|
||||
feature_map[weight_full_name] = weight_size;
|
||||
std::string weight_full_name = item->weight_fullname()->str();
|
||||
size_t weight_size = item->data()->size() * sizeof(float);
|
||||
feature_map[weight_full_name] = weight_size;
|
||||
}
|
||||
}
|
||||
|
||||
bool verifyFeatureMapIsSuccess;
|
||||
if (ps::PSContext::instance()->encrypt_type() == ps::kDSEncryptType && update_model_req->sign() != 0) {
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(update_model_req->index_array(), ResultCode::kFail);
|
||||
verifyFeatureMapIsSuccess = VerifySignDSFeatureMap(feature_map, update_model_req);
|
||||
} else if (ps::PSContext::instance()->upload_compress_type() == kDiffSparseQuant) {
|
||||
verifyFeatureMapIsSuccess = VerifyUploadCompressFeatureMap(update_model_req);
|
||||
} else {
|
||||
verifyFeatureMapIsSuccess = LocalMetaStore::GetInstance().verifyAggregationFeatureMap(feature_map);
|
||||
}
|
||||
|
@ -280,6 +284,45 @@ bool UpdateModelKernel::VerifySignDSFeatureMap(const std::unordered_map<std::str
|
|||
return true;
|
||||
}
|
||||
|
||||
bool UpdateModelKernel::VerifyUploadCompressFeatureMap(const schema::RequestUpdateModel *update_model_req) {
|
||||
auto &aggregation_feature_map_ = LocalMetaStore::GetInstance().aggregation_feature_map();
|
||||
auto upload_sparse_rate = update_model_req->upload_sparse_rate();
|
||||
if (upload_sparse_rate != ps::PSContext::instance()->upload_sparse_rate()) {
|
||||
MS_LOG(WARNING) << "The upload_sparse_rate must be equal to the setting in context.";
|
||||
return false;
|
||||
}
|
||||
auto fbs_name_vec = update_model_req->name_vec();
|
||||
if (fbs_name_vec == nullptr) {
|
||||
MS_LOG(WARNING) << "The name_vec is null.";
|
||||
return false;
|
||||
}
|
||||
if (fbs_name_vec->size() == 0) {
|
||||
MS_LOG(WARNING) << "The size of name_vec must be larger than 0.";
|
||||
return false;
|
||||
}
|
||||
if (fbs_name_vec->size() > aggregation_feature_map_.size()) {
|
||||
MS_LOG(WARNING) << "The size of name_vec must be smaller than model in server.";
|
||||
return false;
|
||||
}
|
||||
for (size_t i = 0; i < fbs_name_vec->size(); ++i) {
|
||||
std::string name = fbs_name_vec->Get(i)->str();
|
||||
if (aggregation_feature_map_.count(name) == 0) {
|
||||
MS_LOG(WARNING) << "The upload name: " << name << " is not in model in server.";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
auto fbs_compress_feature_map = update_model_req->compress_feature_map();
|
||||
if (fbs_compress_feature_map == nullptr) {
|
||||
MS_LOG(WARNING) << "The upload compress feature map is null.";
|
||||
return false;
|
||||
}
|
||||
if (fbs_compress_feature_map->size() == 0) {
|
||||
MS_LOG(WARNING) << "The upload compress feature map is empty.";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
ResultCode UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *update_model_req,
|
||||
const std::shared_ptr<FBBuilder> &fbb, const DeviceMeta &device_meta) {
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(update_model_req, ResultCode::kFail);
|
||||
|
@ -292,6 +335,8 @@ ResultCode UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *upda
|
|||
std::map<std::string, UploadData> feature_map;
|
||||
if (ps::PSContext::instance()->encrypt_type() == ps::kDSEncryptType) {
|
||||
feature_map = ParseSignDSFeatureMap(update_model_req, data_size, &weight_map);
|
||||
} else if (ps::PSContext::instance()->upload_compress_type() == kDiffSparseQuant) {
|
||||
feature_map = ParseUploadCompressFeatureMap(update_model_req, data_size, &weight_map);
|
||||
} else {
|
||||
feature_map = ParseFeatureMap(update_model_req);
|
||||
}
|
||||
|
@ -397,6 +442,89 @@ std::map<std::string, UploadData> UpdateModelKernel::ParseSignDSFeatureMap(
|
|||
return feature_map;
|
||||
}
|
||||
|
||||
std::map<std::string, UploadData> UpdateModelKernel::ParseUploadCompressFeatureMap(
|
||||
const schema::RequestUpdateModel *update_model_req, size_t data_size,
|
||||
std::map<std::string, std::vector<float>> *weight_map) {
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(update_model_req, {});
|
||||
std::map<std::string, UploadData> feature_map;
|
||||
schema::CompressType upload_compress_type = update_model_req->upload_compress_type();
|
||||
upload_compress_type =
|
||||
mindspore::fl::compression::DecodeExecutor::GetInstance().GetCompressType(upload_compress_type);
|
||||
MS_LOG(INFO) << "This schema upload compress type is: " << upload_compress_type;
|
||||
if (upload_compress_type != schema::CompressType_NO_COMPRESS) {
|
||||
MS_LOG(INFO) << "This upload compress type is DIFF_SPARSE_QUANT.";
|
||||
feature_map = DecodeFeatureMap(weight_map, update_model_req, upload_compress_type, data_size);
|
||||
return feature_map;
|
||||
}
|
||||
MS_LOG(INFO) << "This upload compress type is NO_COMPRESS.";
|
||||
// Some clients upload origin weights.
|
||||
auto fbs_feature_map = update_model_req->feature_map();
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(fbs_feature_map, feature_map);
|
||||
for (uint32_t i = 0; i < fbs_feature_map->size(); i++) {
|
||||
std::string weight_full_name = fbs_feature_map->Get(i)->weight_fullname()->str();
|
||||
float *weight_data = const_cast<float *>(fbs_feature_map->Get(i)->data()->data());
|
||||
size_t weight_size = fbs_feature_map->Get(i)->data()->size() * sizeof(float);
|
||||
UploadData upload_data;
|
||||
upload_data[kNewWeight].addr = weight_data;
|
||||
upload_data[kNewWeight].size = weight_size;
|
||||
feature_map[weight_full_name] = upload_data;
|
||||
}
|
||||
return feature_map;
|
||||
}
|
||||
|
||||
std::map<std::string, UploadData> UpdateModelKernel::DecodeFeatureMap(
|
||||
std::map<std::string, std::vector<float>> *weight_map, const schema::RequestUpdateModel *update_model_req,
|
||||
schema::CompressType upload_compress_type, size_t data_size) {
|
||||
std::map<std::string, UploadData> feature_map;
|
||||
|
||||
// Get and set decode hyper parameters.
|
||||
auto seed = update_model_req->iteration();
|
||||
MS_LOG(INFO) << "The seed for compression is: " << seed;
|
||||
auto upload_sparse_rate = update_model_req->upload_sparse_rate();
|
||||
MS_LOG(INFO) << "The upload_sparse_rate for compression is: " << upload_sparse_rate;
|
||||
// Get name vector.
|
||||
auto fbs_name_vec = update_model_req->name_vec();
|
||||
std::vector<std::string> name_vec;
|
||||
for (size_t i = 0; i < fbs_name_vec->size(); ++i) {
|
||||
name_vec.emplace_back(fbs_name_vec->Get(i)->str());
|
||||
}
|
||||
|
||||
// Parameter process for decode.
|
||||
auto fbs_compress_feature_map = update_model_req->compress_feature_map();
|
||||
std::vector<mindspore::fl::compression::CompressFeatureMap> compress_feature_maps;
|
||||
for (size_t i = 0; i < fbs_compress_feature_map->size(); ++i) {
|
||||
mindspore::fl::compression::CompressFeatureMap compress_feature_map;
|
||||
int8_t *compress_weight_data = const_cast<int8_t *>(fbs_compress_feature_map->Get(i)->compress_data()->data());
|
||||
size_t compress_weight_size = fbs_compress_feature_map->Get(i)->compress_data()->size();
|
||||
MS_LOG(INFO) << "The compress weight size: " << compress_weight_size;
|
||||
for (size_t j = 0; j < compress_weight_size; ++j) {
|
||||
compress_feature_map.compress_data.emplace_back(compress_weight_data[j]);
|
||||
}
|
||||
compress_feature_map.min_val = fbs_compress_feature_map->Get(i)->min_val();
|
||||
compress_feature_map.max_val = fbs_compress_feature_map->Get(i)->max_val();
|
||||
MS_LOG(INFO) << "Min value: " << compress_feature_map.min_val;
|
||||
MS_LOG(INFO) << "Max value: " << compress_feature_map.max_val;
|
||||
compress_feature_maps.emplace_back(compress_feature_map);
|
||||
}
|
||||
|
||||
// Decode.
|
||||
bool status = mindspore::fl::compression::DecodeExecutor::GetInstance().Decode(
|
||||
weight_map, compress_feature_maps, upload_compress_type, upload_sparse_rate, seed, name_vec, data_size);
|
||||
if (status) {
|
||||
for (size_t i = 0; i < name_vec.size(); ++i) {
|
||||
std::string weight_full_name = name_vec[i];
|
||||
size_t weight_size = (*weight_map)[weight_full_name].size() * sizeof(float);
|
||||
UploadData upload_data;
|
||||
upload_data[kNewWeight].addr = (*weight_map)[weight_full_name].data();
|
||||
upload_data[kNewWeight].size = weight_size;
|
||||
feature_map[weight_full_name] = upload_data;
|
||||
}
|
||||
return feature_map;
|
||||
}
|
||||
MS_LOG(WARNING) << "Decode failed!";
|
||||
return feature_map;
|
||||
}
|
||||
|
||||
ResultCode UpdateModelKernel::CountForAggregation(const std::string &req_fl_id) {
|
||||
std::string count_reason = "";
|
||||
if (!DistributedCountService::GetInstance().Count(kCountForAggregation, req_fl_id, &count_reason)) {
|
||||
|
|
|
@ -30,6 +30,9 @@
|
|||
#ifdef ENABLE_ARMOUR
|
||||
#include "fl/armour/cipher/cipher_meta_storage.h"
|
||||
#endif
|
||||
#include "fl/compression/decode_executor.h"
|
||||
#include "schema/fl_job_generated.h"
|
||||
#include "schema/cipher_generated.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace fl {
|
||||
|
@ -64,8 +67,12 @@ class UpdateModelKernel : public RoundKernel {
|
|||
std::map<std::string, UploadData> ParseSignDSFeatureMap(const schema::RequestUpdateModel *update_model_req,
|
||||
size_t data_size,
|
||||
std::map<std::string, std::vector<float>> *weight_map);
|
||||
std::map<std::string, UploadData> ParseUploadCompressFeatureMap(
|
||||
const schema::RequestUpdateModel *update_model_req, size_t data_size,
|
||||
std::map<std::string, std::vector<float>> *weight_map);
|
||||
bool VerifySignDSFeatureMap(const std::unordered_map<std::string, size_t> &model,
|
||||
const schema::RequestUpdateModel *update_model_req);
|
||||
bool VerifyUploadCompressFeatureMap(const schema::RequestUpdateModel *update_model_req);
|
||||
ResultCode CountForUpdateModel(const std::shared_ptr<FBBuilder> &fbb,
|
||||
const schema::RequestUpdateModel *update_model_req);
|
||||
sigVerifyResult VerifySignature(const schema::RequestUpdateModel *update_model_req);
|
||||
|
@ -78,6 +85,11 @@ class UpdateModelKernel : public RoundKernel {
|
|||
|
||||
// The time window of one iteration.
|
||||
size_t iteration_time_window_{0};
|
||||
|
||||
// Decode functions of compression.
|
||||
std::map<std::string, UploadData> DecodeFeatureMap(std::map<std::string, std::vector<float>> *weight_map,
|
||||
const schema::RequestUpdateModel *update_model_req,
|
||||
schema::CompressType upload_compress_type, size_t data_size);
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace server
|
||||
|
|
|
@ -44,6 +44,11 @@ void MemoryRegister::StoreCharArray(std::unique_ptr<char[]> *array) {
|
|||
MS_ERROR_IF_NULL_WO_RET_VAL(array);
|
||||
char_arrays_.push_back(std::move(*array));
|
||||
}
|
||||
|
||||
void MemoryRegister::StoreFloat32(std::unique_ptr<float> *param) {
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(param);
|
||||
float_params_.push_back(std::move(*param));
|
||||
}
|
||||
} // namespace server
|
||||
} // namespace fl
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -24,6 +24,7 @@
|
|||
#include <utility>
|
||||
#include <typeinfo>
|
||||
#include "fl/server/common.h"
|
||||
#include "fl/compression/encode_executor.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace fl {
|
||||
|
@ -70,6 +71,25 @@ class MemoryRegister {
|
|||
return;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void RegisterParameter(const std::string &name, std::unique_ptr<T> *param, size_t size) {
|
||||
MS_EXCEPTION_IF_NULL(param);
|
||||
void *data = param->get();
|
||||
AddressPtr addressPtr = std::make_shared<Address>();
|
||||
addressPtr->addr = data;
|
||||
addressPtr->size = size;
|
||||
if (typeid(T) == typeid(float)) {
|
||||
auto float_param = CastUniqueParamPtr<float, T>(param);
|
||||
StoreFloat32(&float_param);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "MemoryRegister does not support type " << typeid(T).name();
|
||||
return;
|
||||
}
|
||||
|
||||
RegisterAddressPtr(name, addressPtr);
|
||||
return;
|
||||
}
|
||||
|
||||
private:
|
||||
std::map<std::string, AddressPtr> addresses_;
|
||||
std::vector<std::unique_ptr<float[]>> float_arrays_;
|
||||
|
@ -86,6 +106,15 @@ class MemoryRegister {
|
|||
std::unique_ptr<T[]> CastUniquePtr(std::unique_ptr<S[]> *array) {
|
||||
return std::unique_ptr<T[]>{reinterpret_cast<T *>(array->release())};
|
||||
}
|
||||
|
||||
std::vector<std::unique_ptr<float>> float_params_;
|
||||
|
||||
void StoreFloat32(std::unique_ptr<float> *array);
|
||||
|
||||
template <typename T, typename S>
|
||||
std::unique_ptr<T> CastUniqueParamPtr(std::unique_ptr<S> *param) {
|
||||
return std::unique_ptr<T>{reinterpret_cast<T *>(param->release())};
|
||||
}
|
||||
};
|
||||
} // namespace server
|
||||
} // namespace fl
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#include <string>
|
||||
#include <memory>
|
||||
#include "fl/server/executor.h"
|
||||
#include "pipeline/jit/parse/parse.h"
|
||||
#include "include/common/utils/python_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -33,6 +34,10 @@ void ModelStore::Initialize(uint32_t rank_id, uint32_t max_count) {
|
|||
max_model_count_ = max_count;
|
||||
initial_model_ = AssignNewModelMemory();
|
||||
iteration_to_model_[kInitIterationNum] = initial_model_;
|
||||
std::map<std::string, AddressPtr> model = Executor::GetInstance().GetModel();
|
||||
for (const auto &item : mindspore::fl::compression::kCompressTypeMap) {
|
||||
iteration_to_compress_model_[kInitIterationNum][item.first] = AssignNewCompressModelMemory(item.first, model);
|
||||
}
|
||||
model_size_ = ComputeModelSize();
|
||||
MS_LOG(INFO) << "Model store checkpoint dir is: " << ps::PSContext::instance()->checkpoint_dir();
|
||||
}
|
||||
|
@ -101,6 +106,24 @@ std::map<std::string, AddressPtr> ModelStore::GetModelByIterNum(size_t iteration
|
|||
return model;
|
||||
}
|
||||
|
||||
std::map<std::string, AddressPtr> ModelStore::GetCompressModelByIterNum(size_t iteration,
|
||||
schema::CompressType compressType) {
|
||||
std::unique_lock<std::mutex> lock(model_mtx_);
|
||||
std::map<std::string, AddressPtr> compressModel = {};
|
||||
if (iteration_to_compress_model_.count(iteration) == 0) {
|
||||
MS_LOG(ERROR) << "Compress Model for iteration " << iteration << " is not stored.";
|
||||
return compressModel;
|
||||
}
|
||||
std::map<schema::CompressType, std::shared_ptr<MemoryRegister>> compress_model_map =
|
||||
iteration_to_compress_model_[iteration];
|
||||
if (compress_model_map.count(compressType) == 0) {
|
||||
MS_LOG(ERROR) << "Compress Model for compress type " << compressType << " is not stored.";
|
||||
return compressModel;
|
||||
}
|
||||
compressModel = iteration_to_compress_model_[iteration][compressType]->addresses();
|
||||
return compressModel;
|
||||
}
|
||||
|
||||
void ModelStore::Reset() {
|
||||
std::unique_lock<std::mutex> lock(model_mtx_);
|
||||
initial_model_ = iteration_to_model_.rbegin()->second;
|
||||
|
@ -114,6 +137,11 @@ const std::map<size_t, std::shared_ptr<MemoryRegister>> &ModelStore::iteration_t
|
|||
return iteration_to_model_;
|
||||
}
|
||||
|
||||
const std::map<size_t, CompressTypeMap> &ModelStore::iteration_to_compress_model() {
|
||||
std::unique_lock<std::mutex> lock(model_mtx_);
|
||||
return iteration_to_compress_model_;
|
||||
}
|
||||
|
||||
size_t ModelStore::model_size() const { return model_size_; }
|
||||
|
||||
std::shared_ptr<MemoryRegister> ModelStore::AssignNewModelMemory() {
|
||||
|
@ -146,6 +174,86 @@ std::shared_ptr<MemoryRegister> ModelStore::AssignNewModelMemory() {
|
|||
return memory_register;
|
||||
}
|
||||
|
||||
std::shared_ptr<MemoryRegister> ModelStore::AssignNewCompressModelMemory(
|
||||
schema::CompressType compressType, const std::map<std::string, AddressPtr> &model) {
|
||||
if (model.empty()) {
|
||||
MS_LOG(EXCEPTION) << "Model feature map is empty.";
|
||||
return nullptr;
|
||||
}
|
||||
std::map<string, std::vector<float>> feature_maps;
|
||||
for (auto &feature_map : model) {
|
||||
auto weight_fullname = feature_map.first;
|
||||
auto weight_data = reinterpret_cast<float *>(feature_map.second->addr);
|
||||
std::vector<float> weight_data_vector{weight_data, weight_data + feature_map.second->size / sizeof(float)};
|
||||
feature_maps[weight_fullname] = weight_data_vector;
|
||||
}
|
||||
|
||||
std::map<std::string, mindspore::fl::compression::CompressWeight> compressWeights;
|
||||
bool status = mindspore::fl::compression::CompressExecutor::GetInstance().construct_compress_weight(
|
||||
&compressWeights, feature_maps, compressType);
|
||||
if (!status) {
|
||||
MS_LOG(ERROR) << "Encode failed!";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Assign new memory for the compress model.
|
||||
std::shared_ptr<MemoryRegister> memory_register = std::make_shared<MemoryRegister>();
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(memory_register, nullptr);
|
||||
MS_LOG(INFO) << "Register compressWeight for compressType: " << schema::EnumNameCompressType(compressType);
|
||||
|
||||
for (const auto &compressWeight : compressWeights) {
|
||||
if (compressType == schema::CompressType_QUANT) {
|
||||
std::string compress_weight_name = compressWeight.first;
|
||||
std::string min_val_name = compress_weight_name + "." + kMinVal;
|
||||
std::string max_val_name = compress_weight_name + "." + kMaxVal;
|
||||
size_t compress_weight_size = compressWeight.second.compress_data_len * sizeof(int8_t);
|
||||
auto compress_weight_data = std::make_unique<char[]>(compress_weight_size);
|
||||
auto src_data_size = compress_weight_size;
|
||||
auto dst_data_size = compress_weight_size;
|
||||
int ret =
|
||||
memcpy_s(compress_weight_data.get(), dst_data_size, compressWeight.second.compress_data.data(), src_data_size);
|
||||
if (ret != 0) {
|
||||
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
|
||||
return nullptr;
|
||||
}
|
||||
memory_register->RegisterArray(compress_weight_name, &compress_weight_data, compress_weight_size);
|
||||
size_t float_size = 1;
|
||||
auto min_val_ptr = std::make_unique<float>(compressWeight.second.min_val);
|
||||
auto max_val_ptr = std::make_unique<float>(compressWeight.second.max_val);
|
||||
|
||||
memory_register->RegisterParameter(min_val_name, &min_val_ptr, float_size);
|
||||
memory_register->RegisterParameter(max_val_name, &max_val_ptr, float_size);
|
||||
}
|
||||
}
|
||||
return memory_register;
|
||||
}
|
||||
|
||||
void ModelStore::StoreCompressModelByIterNum(size_t iteration, const std::map<std::string, AddressPtr> &new_model) {
|
||||
std::unique_lock<std::mutex> lock(model_mtx_);
|
||||
if (iteration_to_compress_model_.count(iteration) != 0) {
|
||||
MS_LOG(WARNING) << "Compress Model for iteration " << iteration << " is already stored";
|
||||
return;
|
||||
}
|
||||
if (new_model.empty()) {
|
||||
MS_LOG(ERROR) << "Compress Model feature map is empty.";
|
||||
return;
|
||||
}
|
||||
|
||||
iteration_to_compress_model_[iteration] = {};
|
||||
if (iteration_to_compress_model_.size() >= max_model_count_) {
|
||||
auto compress_model_map = iteration_to_compress_model_.begin()->second;
|
||||
compress_model_map.clear();
|
||||
(void)iteration_to_compress_model_.erase(iteration_to_compress_model_.begin());
|
||||
}
|
||||
|
||||
for (const auto &item : mindspore::fl::compression::kCompressTypeMap) {
|
||||
auto memory_register = AssignNewCompressModelMemory(item.first, new_model);
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(memory_register);
|
||||
iteration_to_compress_model_[iteration][item.first] = memory_register;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
size_t ModelStore::ComputeModelSize() {
|
||||
std::unique_lock<std::mutex> lock(model_mtx_);
|
||||
if (iteration_to_model_.empty()) {
|
||||
|
@ -179,13 +287,15 @@ void ModelStore::RelModelResponseCache(const void *data, size_t datalen, void *e
|
|||
|
||||
std::shared_ptr<std::vector<uint8_t>> ModelStore::GetModelResponseCache(const string &round_name,
|
||||
size_t cur_iteration_num,
|
||||
size_t model_iteration_num) {
|
||||
size_t model_iteration_num,
|
||||
const std::string &compress_type) {
|
||||
std::unique_lock<std::mutex> lock(model_response_cache_lock_);
|
||||
auto it = std::find_if(model_response_cache_.begin(), model_response_cache_.end(),
|
||||
[&round_name, cur_iteration_num, model_iteration_num](const HttpResponseModelCache &item) {
|
||||
return item.round_name == round_name && item.cur_iteration_num == cur_iteration_num &&
|
||||
item.model_iteration_num == model_iteration_num;
|
||||
});
|
||||
auto it = std::find_if(
|
||||
model_response_cache_.begin(), model_response_cache_.end(),
|
||||
[&round_name, cur_iteration_num, model_iteration_num, &compress_type](const HttpResponseModelCache &item) {
|
||||
return item.round_name == round_name && item.cur_iteration_num == cur_iteration_num &&
|
||||
item.model_iteration_num == model_iteration_num && item.compress_type == compress_type;
|
||||
});
|
||||
if (it == model_response_cache_.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -196,14 +306,16 @@ std::shared_ptr<std::vector<uint8_t>> ModelStore::GetModelResponseCache(const st
|
|||
|
||||
std::shared_ptr<std::vector<uint8_t>> ModelStore::StoreModelResponseCache(const string &round_name,
|
||||
size_t cur_iteration_num,
|
||||
size_t model_iteration_num, const void *data,
|
||||
size_t datalen) {
|
||||
size_t model_iteration_num,
|
||||
const std::string &compress_type,
|
||||
const void *data, size_t datalen) {
|
||||
std::unique_lock<std::mutex> lock(model_response_cache_lock_);
|
||||
auto it = std::find_if(model_response_cache_.begin(), model_response_cache_.end(),
|
||||
[&round_name, cur_iteration_num, model_iteration_num](const HttpResponseModelCache &item) {
|
||||
return item.round_name == round_name && item.cur_iteration_num == cur_iteration_num &&
|
||||
item.model_iteration_num == model_iteration_num;
|
||||
});
|
||||
auto it = std::find_if(
|
||||
model_response_cache_.begin(), model_response_cache_.end(),
|
||||
[&round_name, cur_iteration_num, model_iteration_num, &compress_type](const HttpResponseModelCache &item) {
|
||||
return item.round_name == round_name && item.cur_iteration_num == cur_iteration_num &&
|
||||
item.model_iteration_num == model_iteration_num && item.compress_type == compress_type;
|
||||
});
|
||||
if (it != model_response_cache_.end()) {
|
||||
it->reference_count += 1;
|
||||
total_add_reference_count += 1;
|
||||
|
@ -223,6 +335,7 @@ std::shared_ptr<std::vector<uint8_t>> ModelStore::StoreModelResponseCache(const
|
|||
item.round_name = round_name;
|
||||
item.cur_iteration_num = cur_iteration_num;
|
||||
item.model_iteration_num = model_iteration_num;
|
||||
item.compress_type = compress_type;
|
||||
item.cache = cache;
|
||||
item.reference_count = 1;
|
||||
total_add_reference_count += 1;
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
#include "fl/server/common.h"
|
||||
#include "fl/server/memory_register.h"
|
||||
#include "fl/server/executor.h"
|
||||
#include "fl/compression/encode_executor.h"
|
||||
#include "fl/server/local_meta_store.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -36,6 +37,9 @@ constexpr size_t kInitIterationNum = 0;
|
|||
// The initial iteration number after ModelStore is reset.
|
||||
constexpr size_t kResetInitialIterNum = 1;
|
||||
|
||||
// The compress type map.
|
||||
using CompressTypeMap = std::map<schema::CompressType, std::shared_ptr<MemoryRegister>>;
|
||||
|
||||
// Server framework use ModelStore to store and query models.
|
||||
// ModelStore stores multiple models because worker could get models of the previous iterations.
|
||||
class ModelStore {
|
||||
|
@ -64,15 +68,25 @@ class ModelStore {
|
|||
// Returns the model size, which could be calculated at the initializing phase.
|
||||
size_t model_size() const;
|
||||
|
||||
// Get compress model of the given iteration.
|
||||
std::map<std::string, AddressPtr> GetCompressModelByIterNum(size_t iteration, schema::CompressType compressType);
|
||||
|
||||
const std::map<size_t, std::map<schema::CompressType, std::shared_ptr<MemoryRegister>>>
|
||||
&iteration_to_compress_model();
|
||||
|
||||
void StoreCompressModelByIterNum(size_t iteration, const std::map<std::string, AddressPtr> &new_model);
|
||||
|
||||
static void RelModelResponseCache(const void *data, size_t datalen, void *extra);
|
||||
std::shared_ptr<std::vector<uint8_t>> GetModelResponseCache(const std::string &round_name, size_t cur_iteration_num,
|
||||
size_t model_iteration_num);
|
||||
size_t model_iteration_num,
|
||||
const std::string &compress_type);
|
||||
std::shared_ptr<std::vector<uint8_t>> StoreModelResponseCache(const std::string &round_name, size_t cur_iteration_num,
|
||||
size_t model_iteration_num, const void *data,
|
||||
size_t model_iteration_num,
|
||||
const std::string &compress_type, const void *data,
|
||||
size_t datalen);
|
||||
|
||||
private:
|
||||
ModelStore() : max_model_count_(0), model_size_(0), iteration_to_model_({}) {}
|
||||
ModelStore() : max_model_count_(0), model_size_(0), iteration_to_model_({}), iteration_to_compress_model_({}) {}
|
||||
~ModelStore() = default;
|
||||
ModelStore(const ModelStore &) = delete;
|
||||
ModelStore &operator=(const ModelStore &) = delete;
|
||||
|
@ -83,6 +97,9 @@ class ModelStore {
|
|||
// model_size_.
|
||||
std::shared_ptr<MemoryRegister> AssignNewModelMemory();
|
||||
|
||||
std::shared_ptr<MemoryRegister> AssignNewCompressModelMemory(schema::CompressType compressType,
|
||||
const std::map<std::string, AddressPtr> &model);
|
||||
|
||||
// Calculate the model size. This method should be called after iteration_to_model_ is initialized.
|
||||
size_t ComputeModelSize();
|
||||
|
||||
|
@ -95,12 +112,17 @@ class ModelStore {
|
|||
// The number of all models stored is max_model_count_.
|
||||
std::mutex model_mtx_;
|
||||
std::map<size_t, std::shared_ptr<MemoryRegister>> iteration_to_model_;
|
||||
|
||||
// iteration -> (compress type -> compress model)
|
||||
std::map<size_t, std::map<schema::CompressType, std::shared_ptr<MemoryRegister>>> iteration_to_compress_model_;
|
||||
|
||||
uint32_t rank_id_;
|
||||
|
||||
struct HttpResponseModelCache {
|
||||
std::string round_name; // startFlJob, getModel
|
||||
size_t cur_iteration_num = 0;
|
||||
size_t model_iteration_num = 0;
|
||||
std::string compress_type = kNoCompress;
|
||||
size_t reference_count = 0;
|
||||
std::shared_ptr<std::vector<uint8_t>> cache = nullptr;
|
||||
};
|
||||
|
|
|
@ -507,6 +507,12 @@ PYBIND11_MODULE(_c_expression, m) {
|
|||
.def("set_global_iteration_time_window", &PSContext::set_global_iteration_time_window,
|
||||
"Set global iteration time window.")
|
||||
.def("global_iteration_time_window", &PSContext::global_iteration_time_window, "Get global iteration time window.")
|
||||
.def("set_upload_compress_type", &PSContext::set_upload_compress_type, "Set upload compress type.")
|
||||
.def("upload_compress_type", &PSContext::upload_compress_type, "Get upload compress type.")
|
||||
.def("set_upload_sparse_rate", &PSContext::set_upload_sparse_rate, "Set upload sparse rate.")
|
||||
.def("upload_sparse_rate", &PSContext::upload_sparse_rate, "Get upload sparse rate.")
|
||||
.def("set_download_compress_type", &PSContext::set_download_compress_type, "Set download compress type.")
|
||||
.def("download_compress_type", &PSContext::download_compress_type, "Get download compress type.")
|
||||
.def("set_checkpoint_dir", &PSContext::set_checkpoint_dir, "Set server checkpoint directory.")
|
||||
.def("checkpoint_dir", &PSContext::checkpoint_dir, "Server checkpoint directory.");
|
||||
(void)m.def("_encrypt", &mindspore::pipeline::PyEncrypt, "Encrypt the data.");
|
||||
|
|
|
@ -550,6 +550,19 @@ void PSContext::set_global_iteration_time_window(const uint64_t &global_iteratio
|
|||
|
||||
uint64_t PSContext::global_iteration_time_window() const { return global_iteration_time_window_; }
|
||||
|
||||
void PSContext::set_upload_compress_type(const std::string &upload_compress_type) {
|
||||
upload_compress_type_ = upload_compress_type;
|
||||
}
|
||||
std::string PSContext::upload_compress_type() const { return upload_compress_type_; }
|
||||
|
||||
void PSContext::set_upload_sparse_rate(float upload_sparse_rate) { upload_sparse_rate_ = upload_sparse_rate; }
|
||||
float PSContext::upload_sparse_rate() const { return upload_sparse_rate_; }
|
||||
|
||||
void PSContext::set_download_compress_type(const std::string &download_compress_type) {
|
||||
download_compress_type_ = download_compress_type;
|
||||
}
|
||||
std::string PSContext::download_compress_type() const { return download_compress_type_; }
|
||||
|
||||
std::string PSContext::checkpoint_dir() const { return checkpoint_dir_; }
|
||||
|
||||
void PSContext::set_checkpoint_dir(const std::string &checkpoint_dir) { checkpoint_dir_ = checkpoint_dir; }
|
||||
|
|
|
@ -40,6 +40,7 @@ constexpr char kPWEncryptType[] = "PW_ENCRYPT";
|
|||
constexpr char kStablePWEncryptType[] = "STABLE_PW_ENCRYPT";
|
||||
constexpr char kNotEncryptType[] = "NOT_ENCRYPT";
|
||||
constexpr char kDSEncryptType[] = "SIGNDS";
|
||||
constexpr char kNoCompressType[] = "NO_COMPRESS";
|
||||
|
||||
// Use binary data to represent federated learning server's context so that we can judge which round resets the
|
||||
// iteration. From right to left, each bit stands for:
|
||||
|
@ -230,6 +231,15 @@ class PSContext {
|
|||
void set_global_iteration_time_window(const uint64_t &global_iteration_time_window);
|
||||
uint64_t global_iteration_time_window() const;
|
||||
|
||||
void set_upload_compress_type(const std::string &upload_compress_type);
|
||||
std::string upload_compress_type() const;
|
||||
|
||||
void set_upload_sparse_rate(float upload_sparse_rate);
|
||||
float upload_sparse_rate() const;
|
||||
|
||||
void set_download_compress_type(const std::string &download_compress_type);
|
||||
std::string download_compress_type() const;
|
||||
|
||||
std::string checkpoint_dir() const;
|
||||
void set_checkpoint_dir(const std::string &checkpoint_dir);
|
||||
|
||||
|
@ -286,6 +296,9 @@ class PSContext {
|
|||
server_password_(""),
|
||||
http_url_prefix_(""),
|
||||
global_iteration_time_window_(3600000),
|
||||
upload_compress_type_(kNoCompressType),
|
||||
upload_sparse_rate_(0.4f),
|
||||
download_compress_type_(kNoCompressType),
|
||||
checkpoint_dir_("") {}
|
||||
bool ps_enabled_;
|
||||
bool is_worker_;
|
||||
|
@ -419,6 +432,13 @@ class PSContext {
|
|||
|
||||
// The time window of startFLJob round in millisecond.
|
||||
uint64_t global_iteration_time_window_;
|
||||
|
||||
// Hyper parameters for upload compression.
|
||||
std::string upload_compress_type_;
|
||||
float upload_sparse_rate_;
|
||||
// Hyper parameters for download compression.
|
||||
std::string download_compress_type_;
|
||||
|
||||
// directory of server checkpoint
|
||||
std::string checkpoint_dir_;
|
||||
};
|
||||
|
|
|
@ -105,6 +105,16 @@ public class FLLiteClient {
|
|||
batchSize = flPlan.miniBatch();
|
||||
String serverMod = flPlan.serverMode();
|
||||
localFLParameter.setServerMod(serverMod);
|
||||
// Get and set hyper parameters for compression.
|
||||
byte uploadCompressType = flJob.uploadCompressType();
|
||||
LOGGER.info(Common.addTag("[startFLJob] [compression] uploadCompressType: " + uploadCompressType));
|
||||
localFLParameter.setUploadCompressType(uploadCompressType);
|
||||
float uploadSparseRate = flJob.uploadSparseRate();
|
||||
LOGGER.info(Common.addTag("[startFLJob] [compression] uploadSparseRate: " + uploadSparseRate));
|
||||
localFLParameter.setUploadSparseRatio(uploadSparseRate);
|
||||
int seed = flJob.iteration();
|
||||
LOGGER.info(Common.addTag("[startFLJob] [compression] seed: " + seed));
|
||||
localFLParameter.setSeed(seed);
|
||||
if (Common.checkFLName(flParameter.getFlName())) {
|
||||
deprecatedSetBatchSize(batchSize);
|
||||
} else {
|
||||
|
@ -446,7 +456,7 @@ public class FLLiteClient {
|
|||
return status;
|
||||
}
|
||||
|
||||
private Map<String, float[]> getFeatureMap() {
|
||||
public Map<String, float[]> getFeatureMap() {
|
||||
Map<String, float[]> featureMap = new HashMap<>();
|
||||
if (Common.checkFLName(flParameter.getFlName())) {
|
||||
featureMap = deprecatedGetFeatureMap();
|
||||
|
@ -530,8 +540,7 @@ public class FLLiteClient {
|
|||
localFLParameter.getEncryptLevel().toString() + "> : " + curStatus));
|
||||
return curStatus;
|
||||
case DP_ENCRYPT:
|
||||
// get the feature map before train
|
||||
oldFeatureMap = getFeatureMap();
|
||||
oldFeatureMap = localFLParameter.getOldFeatureMap();
|
||||
curStatus = secureProtocol.setDPParameter(iteration, dpEps, dpDelta, dpNormClipAdapt, oldFeatureMap);
|
||||
retCode = ResponseCode.SUCCEED;
|
||||
if (curStatus != FLClientStatus.SUCCESS) {
|
||||
|
@ -542,8 +551,7 @@ public class FLLiteClient {
|
|||
LOGGER.info(Common.addTag("[Encrypt] set parameters for DP_ENCRYPT!"));
|
||||
return FLClientStatus.SUCCESS;
|
||||
case SIGNDS:
|
||||
// get the feature map before train
|
||||
oldFeatureMap = getFeatureMap();
|
||||
oldFeatureMap = localFLParameter.getOldFeatureMap();
|
||||
curStatus = secureProtocol.setDSParameter(signK, signEps, signThrRatio, signGlobalLr, signDimOut, oldFeatureMap);
|
||||
retCode = ResponseCode.SUCCEED;
|
||||
if (curStatus != FLClientStatus.SUCCESS) {
|
||||
|
|
|
@ -18,7 +18,9 @@ package com.mindspore.flclient;
|
|||
|
||||
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
|
||||
|
||||
import com.mindspore.flclient.compression.CompressMode;
|
||||
import com.mindspore.flclient.model.RunType;
|
||||
import mindspore.schema.CompressType;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
|
@ -603,6 +605,16 @@ public class FLParameter {
|
|||
this.batchSize = batchSize;
|
||||
}
|
||||
|
||||
public byte[] getDownloadCompressTypes() {
|
||||
byte[] downloadCompressTypes = new byte[CompressMode.COMPRESS_TYPE_MAP.size()];
|
||||
int index = 0;
|
||||
for (byte downloadCompressType : CompressMode.COMPRESS_TYPE_MAP.keySet()) {
|
||||
downloadCompressTypes[index] = downloadCompressType;
|
||||
index += 1;
|
||||
}
|
||||
return downloadCompressTypes;
|
||||
}
|
||||
|
||||
public int[][] getInputShape() {
|
||||
return inputShape;
|
||||
}
|
||||
|
|
|
@ -18,6 +18,7 @@ package com.mindspore.flclient;
|
|||
|
||||
import com.google.flatbuffers.FlatBufferBuilder;
|
||||
|
||||
import com.mindspore.flclient.compression.DecodeExecutor;
|
||||
import com.mindspore.flclient.model.AlInferBert;
|
||||
import com.mindspore.flclient.model.AlTrainBert;
|
||||
import com.mindspore.flclient.model.Client;
|
||||
|
@ -27,11 +28,9 @@ import com.mindspore.flclient.model.SessionUtil;
|
|||
import com.mindspore.flclient.model.Status;
|
||||
|
||||
import com.mindspore.flclient.model.TrainLenet;
|
||||
import mindspore.schema.FeatureMap;
|
||||
import mindspore.schema.RequestGetModel;
|
||||
import mindspore.schema.ResponseCode;
|
||||
import mindspore.schema.ResponseGetModel;
|
||||
import mindspore.schema.*;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Date;
|
||||
import java.util.logging.Logger;
|
||||
|
@ -94,7 +93,8 @@ public class GetModel {
|
|||
throw new IllegalArgumentException();
|
||||
}
|
||||
RequestGetModelBuilder builder = new RequestGetModelBuilder();
|
||||
return builder.iteration(iteration).flName(name).time().build();
|
||||
return builder.iteration(iteration).flName(name).time()
|
||||
.downloadCompressTypesBuilder(flParameter.getDownloadCompressTypes()).build();
|
||||
}
|
||||
|
||||
private FLClientStatus deprecatedParseResponseAlbert(ResponseGetModel responseDataBuf) {
|
||||
|
@ -226,11 +226,29 @@ public class GetModel {
|
|||
return status;
|
||||
}
|
||||
|
||||
private List<FeatureMap> parseFeatureMapList(ResponseGetModel responseDataBuf) {
|
||||
List<FeatureMap> featureMaps;
|
||||
byte compressType = responseDataBuf.downloadCompressType();
|
||||
if (responseDataBuf.downloadCompressType() == mindspore.schema.CompressType.NO_COMPRESS) {
|
||||
featureMaps = new ArrayList<>();
|
||||
for (int i = 0; i < responseDataBuf.featureMapLength(); i++) {
|
||||
featureMaps.add(responseDataBuf.featureMap(i));
|
||||
}
|
||||
} else {
|
||||
List<mindspore.schema.CompressFeatureMap> compressFeatureMapList = new ArrayList<>();
|
||||
for (int i = 0; i < responseDataBuf.compressFeatureMapLength(); i++) {
|
||||
compressFeatureMapList.add(responseDataBuf.compressFeatureMap(i));
|
||||
}
|
||||
featureMaps = DecodeExecutor.getInstance().deCompressWeight(compressType, compressFeatureMapList);
|
||||
}
|
||||
return featureMaps;
|
||||
}
|
||||
|
||||
private FLClientStatus parseResponseFeatures(ResponseGetModel responseDataBuf) {
|
||||
FLClientStatus status;
|
||||
Client client = ClientManager.getClient(flParameter.getFlName());
|
||||
int fmCount = responseDataBuf.featureMapLength();
|
||||
if (fmCount <= 0) {
|
||||
List<FeatureMap> featureMapList = parseFeatureMapList(responseDataBuf);
|
||||
if (featureMapList.size() <= 0) {
|
||||
LOGGER.severe(Common.addTag("[getModel] the feature size get from server is zero"));
|
||||
retCode = ResponseCode.SystemError;
|
||||
return FLClientStatus.FAILED;
|
||||
|
@ -239,8 +257,8 @@ public class GetModel {
|
|||
LOGGER.info(Common.addTag("[getModel] parseResponseFeatures by " + localFLParameter.getServerMod()));
|
||||
ArrayList<FeatureMap> trainFeatureMaps = new ArrayList<FeatureMap>();
|
||||
ArrayList<FeatureMap> inferFeatureMaps = new ArrayList<FeatureMap>();
|
||||
for (int i = 0; i < fmCount; i++) {
|
||||
FeatureMap feature = responseDataBuf.featureMap(i);
|
||||
for (int i = 0; i < featureMapList.size(); i++) {
|
||||
FeatureMap feature = featureMapList.get(i);
|
||||
if (feature == null) {
|
||||
LOGGER.severe(Common.addTag("[getModel] the feature returned from server is null"));
|
||||
retCode = ResponseCode.SystemError;
|
||||
|
@ -289,8 +307,8 @@ public class GetModel {
|
|||
} else if (localFLParameter.getServerMod().equals(ServerMod.FEDERATED_LEARNING.toString())) {
|
||||
LOGGER.info(Common.addTag("[getModel] parseResponseFeatures by " + localFLParameter.getServerMod()));
|
||||
ArrayList<FeatureMap> featureMaps = new ArrayList<FeatureMap>();
|
||||
for (int i = 0; i < fmCount; i++) {
|
||||
FeatureMap feature = responseDataBuf.featureMap(i);
|
||||
for (int i = 0; i < featureMapList.size(); i++) {
|
||||
FeatureMap feature = featureMapList.get(i);
|
||||
if (feature == null) {
|
||||
LOGGER.severe(Common.addTag("[getModel] the feature returned from server is null"));
|
||||
retCode = ResponseCode.SystemError;
|
||||
|
@ -365,6 +383,7 @@ public class GetModel {
|
|||
private int nameOffset = 0;
|
||||
private int iteration = 0;
|
||||
private int timeStampOffset = 0;
|
||||
private int downloadCompressTypesOffset = 0;
|
||||
|
||||
public RequestGetModelBuilder() {
|
||||
builder = new FlatBufferBuilder();
|
||||
|
@ -392,11 +411,23 @@ public class GetModel {
|
|||
return this;
|
||||
}
|
||||
|
||||
private RequestGetModelBuilder downloadCompressTypesBuilder(byte[] downloadCompressTypes) {
|
||||
if (downloadCompressTypes == null || downloadCompressTypes.length == 0) {
|
||||
LOGGER.severe(Common.addTag("[GetModel] the parameter of <downloadCompressTypes> is null or empty," +
|
||||
" please check!"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
this.downloadCompressTypesOffset = RequestGetModel.createDownloadCompressTypesVector(builder,
|
||||
downloadCompressTypes);
|
||||
return this;
|
||||
}
|
||||
|
||||
private byte[] build() {
|
||||
RequestGetModel.startRequestGetModel(builder);
|
||||
RequestGetModel.addFlName(builder, nameOffset);
|
||||
RequestGetModel.addIteration(builder, iteration);
|
||||
RequestGetModel.addTimestamp(builder, timeStampOffset);
|
||||
RequestGetModel.addDownloadCompressTypes(builder, downloadCompressTypesOffset);
|
||||
int root = RequestGetModel.endRequestGetModel(builder);
|
||||
builder.finish(root);
|
||||
return builder.sizedByteArray();
|
||||
|
|
|
@ -22,6 +22,7 @@ import org.bouncycastle.math.ec.rfc7748.X25519;
|
|||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
/**
|
||||
|
@ -83,6 +84,10 @@ public class LocalFLParameter {
|
|||
private MSConfig msConfig = new MSConfig();
|
||||
private boolean useSSL = true;
|
||||
private float lr = 0.1f;
|
||||
private Map<String, float[]> oldFeatureMap;
|
||||
private byte uploadCompressType = 0;
|
||||
private int seed = 0;
|
||||
private float uploadSparseRatio = 0.08f;
|
||||
|
||||
|
||||
private LocalFLParameter() {
|
||||
|
@ -250,4 +255,36 @@ public class LocalFLParameter {
|
|||
public void setLr(float lr) {
|
||||
this.lr = lr;
|
||||
}
|
||||
|
||||
public Map<String, float[]> getOldFeatureMap() {
|
||||
return oldFeatureMap;
|
||||
}
|
||||
|
||||
public void setOldFeatureMap(Map<String, float[]> oldFeatureMap) {
|
||||
this.oldFeatureMap = oldFeatureMap;
|
||||
}
|
||||
|
||||
public byte getUploadCompressType() {
|
||||
return uploadCompressType;
|
||||
}
|
||||
|
||||
public void setUploadCompressType(byte uploadCompressType) {
|
||||
this.uploadCompressType = uploadCompressType;
|
||||
}
|
||||
|
||||
public int getSeed() {
|
||||
return seed;
|
||||
}
|
||||
|
||||
public void setSeed(int seed) {
|
||||
this.seed = seed;
|
||||
}
|
||||
|
||||
public float getUploadSparseRatio() {
|
||||
return uploadSparseRatio;
|
||||
}
|
||||
|
||||
public void setUploadSparseRatio(float uploadSparseRatio) {
|
||||
this.uploadSparseRatio = uploadSparseRatio;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -208,35 +208,34 @@ public class SecureProtocol {
|
|||
* @param trainDataSize trainDataSize tne size of train data set.
|
||||
* @return the serialized model weights after adding masks.
|
||||
*/
|
||||
public int[] pwMaskModel(FlatBufferBuilder builder, int trainDataSize, Map<String, float[]> trainedMap) {
|
||||
public Map<String, List<Float>> pwMaskModel(FlatBufferBuilder builder, int trainDataSize, Map<String,
|
||||
float[]> trainedMap) {
|
||||
Map<String, List<Float>> featureMaps = new HashMap<>();
|
||||
if (featureMask == null || featureMask.length == 0) {
|
||||
LOGGER.severe("[Encrypt] feature mask is null, please check");
|
||||
return new int[0];
|
||||
return new HashMap<>();
|
||||
}
|
||||
LOGGER.info(String.format("[Encrypt] feature mask size: %s", featureMask.length));
|
||||
int featureSize = updateFeatureName.size();
|
||||
int[] featuresMap = new int[featureSize];
|
||||
int maskIndex = 0;
|
||||
for (int i = 0; i < featureSize; i++) {
|
||||
String key = updateFeatureName.get(i);
|
||||
float[] data = trainedMap.get(key);
|
||||
List<Float> featureMap = new ArrayList<>();
|
||||
LOGGER.info(String.format("[Encrypt] feature name: %s feature size: %s", key, data.length));
|
||||
for (int j = 0; j < data.length; j++) {
|
||||
float rawData = data[j];
|
||||
if (maskIndex >= featureMask.length) {
|
||||
LOGGER.severe("[Encrypt] the maskIndex is out of range for array featureMask, please check");
|
||||
return new int[0];
|
||||
return new HashMap<>();
|
||||
}
|
||||
float maskData = rawData * trainDataSize + featureMask[maskIndex];
|
||||
maskIndex += 1;
|
||||
data[j] = maskData;
|
||||
featureMap.add(maskData);
|
||||
}
|
||||
int featureName = builder.createString(key);
|
||||
int weight = FeatureMap.createDataVector(builder, data);
|
||||
int featureMap = FeatureMap.createFeatureMap(builder, featureName, weight);
|
||||
featuresMap[i] = featureMap;
|
||||
featureMaps.put(key, featureMap);
|
||||
}
|
||||
return featuresMap;
|
||||
return featureMaps;
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -365,7 +364,9 @@ public class SecureProtocol {
|
|||
* @param trainDataSize tne size of train data set.
|
||||
* @return the serialized model weights after adding masks.
|
||||
*/
|
||||
public int[] dpMaskModel(FlatBufferBuilder builder, int trainDataSize, Map<String, float[]> trainedMap) {
|
||||
public Map<String, List<Float>> dpMaskModel(FlatBufferBuilder builder, int trainDataSize,
|
||||
Map<String, float[]> trainedMap) {
|
||||
Map<String, List<Float>> featureMaps = new HashMap<>();
|
||||
// get feature map
|
||||
Map<String, float[]> mapBeforeTrain = modelMap;
|
||||
int featureSize = updateFeatureName.size();
|
||||
|
@ -383,7 +384,7 @@ public class SecureProtocol {
|
|||
float rawData = data[j];
|
||||
if (j >= dataBeforeTrain.length) {
|
||||
LOGGER.severe("[Encrypt] the index j is out of range for array dataBeforeTrain, please check");
|
||||
return new int[0];
|
||||
return new HashMap<>();
|
||||
}
|
||||
float rawDataBeforeTrain = dataBeforeTrain[j];
|
||||
float updateData = rawData - rawDataBeforeTrain;
|
||||
|
@ -393,23 +394,23 @@ public class SecureProtocol {
|
|||
updateL2Norm = Math.sqrt(updateL2Norm);
|
||||
if (updateL2Norm == 0) {
|
||||
LOGGER.severe(Common.addTag("[Encrypt] updateL2Norm is 0, please check"));
|
||||
return new int[0];
|
||||
return new HashMap<>();
|
||||
}
|
||||
double clipFactor = Math.min(1.0, dpNormClip / updateL2Norm);
|
||||
|
||||
// clip and add noise
|
||||
int[] featuresMap = new int[featureSize];
|
||||
for (int i = 0; i < featureSize; i++) {
|
||||
String key = updateFeatureName.get(i);
|
||||
if (!trainedMap.containsKey(key)) {
|
||||
LOGGER.severe("[Encrypt] the key: " + key + " is not in map, please check!");
|
||||
return new int[0];
|
||||
return new HashMap<>();
|
||||
}
|
||||
float[] data = trainedMap.get(key);
|
||||
float[] data2 = new float[data.length];
|
||||
List<Float> featureMap = new ArrayList<>();
|
||||
if (!mapBeforeTrain.containsKey(key)) {
|
||||
LOGGER.severe("[Encrypt] the key: " + key + " is not in mapBeforeTrain, please check!");
|
||||
return new int[0];
|
||||
return new HashMap<>();
|
||||
}
|
||||
float[] dataBeforeTrain = mapBeforeTrain.get(key);
|
||||
|
||||
|
@ -419,7 +420,7 @@ public class SecureProtocol {
|
|||
float rawData = data[j];
|
||||
if (j >= dataBeforeTrain.length) {
|
||||
LOGGER.severe("[Encrypt] the index j is out of range for array dataBeforeTrain, please check");
|
||||
return new int[0];
|
||||
return new HashMap<>();
|
||||
}
|
||||
float rawDataBeforeTrain = dataBeforeTrain[j];
|
||||
float updateData = rawData - rawDataBeforeTrain;
|
||||
|
@ -432,13 +433,11 @@ public class SecureProtocol {
|
|||
updateData += gaussianNoise;
|
||||
data2[j] = rawDataBeforeTrain + updateData;
|
||||
data2[j] = data2[j] * trainDataSize;
|
||||
featureMap.add(data2[j]);
|
||||
}
|
||||
int featureName = builder.createString(key);
|
||||
int weight = FeatureMap.createDataVector(builder, data2);
|
||||
int featureMap = FeatureMap.createFeatureMap(builder, featureName, weight);
|
||||
featuresMap[i] = featureMap;
|
||||
featureMaps.put(key, featureMap);
|
||||
}
|
||||
return featuresMap;
|
||||
return featureMaps;
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -18,6 +18,7 @@ package com.mindspore.flclient;
|
|||
|
||||
import com.google.flatbuffers.FlatBufferBuilder;
|
||||
|
||||
import com.mindspore.flclient.compression.DecodeExecutor;
|
||||
import com.mindspore.flclient.model.AlInferBert;
|
||||
import com.mindspore.flclient.model.AlTrainBert;
|
||||
import com.mindspore.flclient.model.Client;
|
||||
|
@ -29,6 +30,7 @@ import com.mindspore.flclient.model.TrainLenet;
|
|||
import com.mindspore.flclient.pki.PkiBean;
|
||||
import com.mindspore.flclient.pki.PkiUtil;
|
||||
|
||||
import mindspore.schema.*;
|
||||
import mindspore.schema.FLPlan;
|
||||
import mindspore.schema.FeatureMap;
|
||||
import mindspore.schema.RequestFLJob;
|
||||
|
@ -38,6 +40,7 @@ import mindspore.schema.ResponseFLJob;
|
|||
import java.io.IOException;
|
||||
import java.security.cert.Certificate;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
|
||||
|
@ -119,6 +122,7 @@ public class StartFLJob {
|
|||
.iteration(iteration)
|
||||
.signData(pkiBean.getSignData())
|
||||
.certificateChain(pkiBean.getCertificates())
|
||||
.downloadCompressTypesBuilder(flParameter.getDownloadCompressTypes())
|
||||
.build();
|
||||
}
|
||||
return builder.flName(flParameter.getFlName())
|
||||
|
@ -126,6 +130,7 @@ public class StartFLJob {
|
|||
.id(localFLParameter.getFlID())
|
||||
.dataSize(dataSize)
|
||||
.iteration(iteration)
|
||||
.downloadCompressTypesBuilder(flParameter.getDownloadCompressTypes())
|
||||
.build();
|
||||
}
|
||||
|
||||
|
@ -151,8 +156,9 @@ public class StartFLJob {
|
|||
ArrayList<FeatureMap> albertFeatureMaps = new ArrayList<FeatureMap>();
|
||||
ArrayList<FeatureMap> inferFeatureMaps = new ArrayList<FeatureMap>();
|
||||
featureSize = 0;
|
||||
for (int i = 0; i < fmCount; i++) {
|
||||
FeatureMap feature = flJob.featureMap(i);
|
||||
List<FeatureMap> featureMapList = parseFeatureMapList(flJob);
|
||||
for (int i = 0; i < featureMapList.size(); i++) {
|
||||
FeatureMap feature = featureMapList.get(i);
|
||||
if (feature == null) {
|
||||
LOGGER.severe(Common.addTag("[startFLJob] the feature returned from server is null"));
|
||||
return FLClientStatus.FAILED;
|
||||
|
@ -233,12 +239,14 @@ public class StartFLJob {
|
|||
|
||||
private FLClientStatus deprecatedParseResponseLenet(ResponseFLJob flJob) {
|
||||
FLClientStatus status;
|
||||
int fmCount = flJob.featureMapLength();
|
||||
ArrayList<FeatureMap> featureMaps = new ArrayList<FeatureMap>();
|
||||
updateFeatureName.clear();
|
||||
featureSize = 0;
|
||||
for (int i = 0; i < fmCount; i++) {
|
||||
FeatureMap feature = flJob.featureMap(i);
|
||||
List<FeatureMap> featureMapList = parseFeatureMapList(flJob);
|
||||
|
||||
ArrayList<FeatureMap> featureMaps = new ArrayList<>();
|
||||
|
||||
for (int i = 0; i < featureMapList.size(); i++) {
|
||||
FeatureMap feature = featureMapList.get(i);
|
||||
if (feature == null) {
|
||||
LOGGER.severe(Common.addTag("[startFLJob] the feature returned from server is null"));
|
||||
return FLClientStatus.FAILED;
|
||||
|
@ -267,6 +275,24 @@ public class StartFLJob {
|
|||
return FLClientStatus.SUCCESS;
|
||||
}
|
||||
|
||||
private List<FeatureMap> parseFeatureMapList(ResponseFLJob flJob) {
|
||||
List<FeatureMap> featureMaps;
|
||||
byte compressType = flJob.downloadCompressType();
|
||||
if (flJob.downloadCompressType() == mindspore.schema.CompressType.NO_COMPRESS) {
|
||||
LOGGER.info(Common.addTag("[parseFeatureMapList] create no compress feature map."));
|
||||
featureMaps = new ArrayList<>();
|
||||
for (int i = 0; i < flJob.featureMapLength(); i++) {
|
||||
featureMaps.add(flJob.featureMap(i));
|
||||
}
|
||||
} else {
|
||||
List<CompressFeatureMap> compressFeatureMapList = new ArrayList<>();
|
||||
for (int i = 0; i < flJob.compressFeatureMapLength(); i++) {
|
||||
compressFeatureMapList.add(flJob.compressFeatureMap(i));
|
||||
}
|
||||
featureMaps = DecodeExecutor.getInstance().deCompressWeight(compressType, compressFeatureMapList);
|
||||
}
|
||||
return featureMaps;
|
||||
}
|
||||
|
||||
private FLClientStatus hybridFeatures(ResponseFLJob flJob) {
|
||||
FLClientStatus status;
|
||||
|
@ -275,8 +301,23 @@ public class StartFLJob {
|
|||
ArrayList<FeatureMap> trainFeatureMaps = new ArrayList<FeatureMap>();
|
||||
ArrayList<FeatureMap> inferFeatureMaps = new ArrayList<FeatureMap>();
|
||||
featureSize = 0;
|
||||
List<FeatureMap> featureMaps;
|
||||
byte compressType = flJob.downloadCompressType();
|
||||
if (compressType == CompressType.NO_COMPRESS) {
|
||||
featureMaps = new ArrayList<>();
|
||||
for (int i = 0; i < fmCount; i++) {
|
||||
featureMaps.add(flJob.featureMap(i));
|
||||
}
|
||||
} else {
|
||||
List<CompressFeatureMap> compressFeatureMapList = new ArrayList<>();
|
||||
for (int i = 0; i < flJob.compressFeatureMapLength(); i++) {
|
||||
compressFeatureMapList.add(flJob.compressFeatureMap(i));
|
||||
}
|
||||
featureMaps = DecodeExecutor.getInstance().deCompressWeight(compressType, compressFeatureMapList);
|
||||
fmCount = featureMaps.size();
|
||||
}
|
||||
for (int i = 0; i < fmCount; i++) {
|
||||
FeatureMap feature = flJob.featureMap(i);
|
||||
FeatureMap feature = featureMaps.get(i);
|
||||
if (feature == null) {
|
||||
LOGGER.severe(Common.addTag("[startFLJob] the feature returned from server is null"));
|
||||
retCode = ResponseCode.SystemError;
|
||||
|
@ -335,8 +376,23 @@ public class StartFLJob {
|
|||
int fmCount = flJob.featureMapLength();
|
||||
ArrayList<FeatureMap> featureMaps = new ArrayList<FeatureMap>();
|
||||
featureSize = 0;
|
||||
byte compressType = flJob.downloadCompressType();
|
||||
List<FeatureMap> parseFeatureMaps;
|
||||
if (compressType == CompressType.NO_COMPRESS) {
|
||||
parseFeatureMaps = new ArrayList<>();
|
||||
for (int i = 0; i < fmCount; i++) {
|
||||
parseFeatureMaps.add(flJob.featureMap(i));
|
||||
}
|
||||
} else {
|
||||
List<CompressFeatureMap> compressFeatureMapList = new ArrayList<>();
|
||||
for (int i = 0; i < flJob.compressFeatureMapLength(); i++) {
|
||||
compressFeatureMapList.add(flJob.compressFeatureMap(i));
|
||||
}
|
||||
parseFeatureMaps = DecodeExecutor.getInstance().deCompressWeight(compressType, compressFeatureMapList);
|
||||
fmCount = parseFeatureMaps.size();
|
||||
}
|
||||
for (int i = 0; i < fmCount; i++) {
|
||||
FeatureMap feature = flJob.featureMap(i);
|
||||
FeatureMap feature = parseFeatureMaps.get(i);
|
||||
if (feature == null) {
|
||||
LOGGER.severe(Common.addTag("[startFLJob] the feature returned from server is null"));
|
||||
retCode = ResponseCode.SystemError;
|
||||
|
@ -437,8 +493,8 @@ public class StartFLJob {
|
|||
|
||||
switch (responseRetCode) {
|
||||
case (ResponseCode.SUCCEED):
|
||||
if (flJob.featureMapLength() <= 0) {
|
||||
LOGGER.severe(Common.addTag("[startFLJob] the feature size get from server is zero"));
|
||||
if (flJob.downloadCompressType() == CompressType.NO_COMPRESS && flJob.featureMapLength() <= 0) {
|
||||
LOGGER.warning(Common.addTag("[startFLJob] the feature size get from server is zero"));
|
||||
retCode = ResponseCode.SystemError;
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
|
@ -484,6 +540,7 @@ public class StartFLJob {
|
|||
private int equipCertOffset = 0;
|
||||
private int equipCACertOffset = 0;
|
||||
private int rootCertOffset = 0;
|
||||
private int downloadCompressTypesOffset = 0;
|
||||
|
||||
public RequestStartFLJobBuilder() {
|
||||
builder = new FlatBufferBuilder();
|
||||
|
@ -598,6 +655,17 @@ public class StartFLJob {
|
|||
return this;
|
||||
}
|
||||
|
||||
private RequestStartFLJobBuilder downloadCompressTypesBuilder(byte[] downloadCompressTypes) {
|
||||
if (downloadCompressTypes == null || downloadCompressTypes.length == 0) {
|
||||
LOGGER.severe(Common.addTag("[StartFLJob] the parameter of <downloadCompressTypes> is null or empty," +
|
||||
" please check!"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
this.downloadCompressTypesOffset = RequestFLJob.createDownloadCompressTypesVector(builder,
|
||||
downloadCompressTypes);
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* build protobuffer
|
||||
*
|
||||
|
@ -615,6 +683,7 @@ public class StartFLJob {
|
|||
RequestFLJob.addEquipCaCert(builder, equipCACertOffset);
|
||||
RequestFLJob.addEquipCert(builder, equipCertOffset);
|
||||
RequestFLJob.addKeyAttestation(builder, keyAttestationOffset);
|
||||
RequestFLJob.addDownloadCompressTypes(builder, downloadCompressTypesOffset);
|
||||
int root = RequestFLJob.endRequestFLJob(builder);
|
||||
builder.finish(root);
|
||||
return builder.sizedByteArray();
|
||||
|
|
|
@ -147,6 +147,10 @@ public class SyncFLJob {
|
|||
LOGGER.info(Common.addTag("[startFLJob] startFLJob succeed, curIteration: " + flLiteClient.getIteration()));
|
||||
updateTryTimePerIter(flLiteClient);
|
||||
|
||||
// Copy weights before training.
|
||||
Map<String, float[]> oldFeatureMap = flLiteClient.getFeatureMap();
|
||||
localFLParameter.setOldFeatureMap(oldFeatureMap);
|
||||
|
||||
// create mask
|
||||
curStatus = flLiteClient.getFeatureMask();
|
||||
if (curStatus == FLClientStatus.RESTART) {
|
||||
|
|
|
@ -26,11 +26,15 @@ import com.mindspore.flclient.model.SessionUtil;
|
|||
import com.mindspore.flclient.model.Status;
|
||||
import com.mindspore.flclient.model.TrainLenet;
|
||||
import com.mindspore.lite.MSTensor;
|
||||
import com.mindspore.flclient.compression.EncodeExecutor;
|
||||
import com.mindspore.flclient.compression.CompressWeight;
|
||||
|
||||
import mindspore.schema.FeatureMap;
|
||||
import mindspore.schema.CompressFeatureMap;
|
||||
import mindspore.schema.RequestUpdateModel;
|
||||
import mindspore.schema.ResponseCode;
|
||||
import mindspore.schema.ResponseUpdateModel;
|
||||
import static mindspore.schema.CompressType.NO_COMPRESS;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Date;
|
||||
|
@ -208,6 +212,7 @@ public class UpdateModel {
|
|||
private RequestUpdateModel requestUM;
|
||||
private FlatBufferBuilder builder;
|
||||
private int fmOffset = 0;
|
||||
private int compFmOffset = 0;
|
||||
private int nameOffset = 0;
|
||||
private int idOffset = 0;
|
||||
private int timestampOffset = 0;
|
||||
|
@ -215,8 +220,11 @@ public class UpdateModel {
|
|||
private int sign = 0;
|
||||
private int indexArrayOffset = 0;
|
||||
private int iteration = 0;
|
||||
private byte uploadCompressType = 0;
|
||||
private float uploadSparseRate = 0.0f;
|
||||
private EncryptLevel encryptLevel = EncryptLevel.NOT_ENCRYPT;
|
||||
private float uploadLossOffset = 0.0f;
|
||||
private int nameVecOffset = 0;
|
||||
|
||||
private RequestUpdateModelBuilder(EncryptLevel encryptLevel) {
|
||||
builder = new FlatBufferBuilder();
|
||||
|
@ -294,34 +302,33 @@ public class UpdateModel {
|
|||
} else {
|
||||
trainedMap = getFeatureMap();
|
||||
}
|
||||
Map<String, List<Float>> featureMaps = new HashMap<>();
|
||||
long startTime;
|
||||
long endTime;
|
||||
switch (encryptLevel) {
|
||||
case PW_ENCRYPT:
|
||||
int[] fmOffsetsPW = secureProtocol.pwMaskModel(builder, trainDataSize, trainedMap);
|
||||
if (fmOffsetsPW == null || fmOffsetsPW.length == 0) {
|
||||
LOGGER.severe("[Encrypt] the return fmOffsetsPW from <secureProtocol.pwMaskModel> is " +
|
||||
featureMaps = secureProtocol.pwMaskModel(builder, trainDataSize, trainedMap);
|
||||
if (featureMaps == null || featureMaps.size() == 0) {
|
||||
LOGGER.severe("[Encrypt] the return featureMaps from <secureProtocol.pwMaskModel> is " +
|
||||
"null, please check");
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsetsPW);
|
||||
LOGGER.info(Common.addTag("[Encrypt] pairwise mask model ok!"));
|
||||
return this;
|
||||
break;
|
||||
case DP_ENCRYPT:
|
||||
startTime = System.currentTimeMillis();
|
||||
int[] fmOffsetsDP = secureProtocol.dpMaskModel(builder, trainDataSize, trainedMap);
|
||||
if (fmOffsetsDP == null || fmOffsetsDP.length == 0) {
|
||||
LOGGER.severe("[Encrypt] the return fmOffsetsDP from <secureProtocol.dpMaskModel> is " +
|
||||
featureMaps = secureProtocol.dpMaskModel(builder, trainDataSize, trainedMap);
|
||||
if (featureMaps == null || featureMaps.size() == 0) {
|
||||
LOGGER.severe("[Encrypt] the return featureMaps from <secureProtocol.dpMaskModel> is " +
|
||||
"null, please check");
|
||||
retCode = ResponseCode.RequestError;
|
||||
status = FLClientStatus.FAILED;
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsetsDP);
|
||||
LOGGER.info(Common.addTag("[Encrypt] DP mask model ok!"));
|
||||
endTime = System.currentTimeMillis();
|
||||
LOGGER.info(Common.addTag("[Encrypt] dp time is: " + (endTime - startTime) + "ms"));
|
||||
return this;
|
||||
LOGGER.info(Common.addTag("dp time is " + (endTime - startTime) + "ms"));
|
||||
break;
|
||||
case SIGNDS:
|
||||
startTime = System.currentTimeMillis();
|
||||
// signds alg return indexArray, and package indexArray into flatbuffer.
|
||||
|
@ -352,31 +359,104 @@ public class UpdateModel {
|
|||
this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsetsSignds);
|
||||
LOGGER.info(Common.addTag("[Encrypt] SignDS mask model ok!"));
|
||||
endTime = System.currentTimeMillis();
|
||||
LOGGER.info(Common.addTag("[Encrypt] signds time is: " + (endTime - startTime) + "ms"));
|
||||
LOGGER.info(Common.addTag("signds time is " + (endTime - startTime) + "ms"));
|
||||
return this;
|
||||
case NOT_ENCRYPT:
|
||||
default:
|
||||
startTime = System.currentTimeMillis();
|
||||
int featureSize = updateFeatureName.size();
|
||||
int[] fmOffsets = new int[featureSize];
|
||||
for (int i = 0; i < featureSize; i++) {
|
||||
String key = updateFeatureName.get(i);
|
||||
float[] data = trainedMap.get(key);
|
||||
LOGGER.info(Common.addTag("[updateModel build featuresMap] feature name: " + key + " feature " +
|
||||
"size: " + data.length));
|
||||
for (int j = 0; j < data.length; j++) {
|
||||
data[j] = data[j] * trainDataSize;
|
||||
for (String name : updateFeatureName) {
|
||||
float[] data = trainedMap.get(name);
|
||||
List<Float> featureMap = new ArrayList<>();
|
||||
for (float datum : data) {
|
||||
featureMap.add(datum * (float) trainDataSize);
|
||||
}
|
||||
int featureName = builder.createString(key);
|
||||
int weight = FeatureMap.createDataVector(builder, data);
|
||||
int featureMap = FeatureMap.createFeatureMap(builder, featureName, weight);
|
||||
fmOffsets[i] = featureMap;
|
||||
featureMaps.put(name, featureMap);
|
||||
}
|
||||
this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsets);
|
||||
endTime = System.currentTimeMillis();
|
||||
LOGGER.info(Common.addTag("[Encrypt] not encrypt time is: " + (endTime - startTime) + "ms"));
|
||||
return this;
|
||||
LOGGER.info(Common.addTag("not encrypt time is " + (endTime - startTime) + "ms"));
|
||||
break;
|
||||
}
|
||||
byte uploadCompressType = localFLParameter.getUploadCompressType();
|
||||
if (uploadCompressType != NO_COMPRESS) {
|
||||
startTime = System.currentTimeMillis();
|
||||
this.compFmOffset = buildCompFmOffset(featureMaps, trainDataSize);
|
||||
this.uploadCompressType = localFLParameter.getUploadCompressType();
|
||||
this.uploadSparseRate = localFLParameter.getUploadSparseRatio();
|
||||
this.nameVecOffset = buildNameVecOffset(updateFeatureName);
|
||||
endTime = System.currentTimeMillis();
|
||||
LOGGER.info(Common.addTag("compression time is " + (endTime - startTime) + "ms"));
|
||||
return this;
|
||||
}
|
||||
this.fmOffset = buildFmOffset(featureMaps, updateFeatureName);
|
||||
return this;
|
||||
}
|
||||
|
||||
private int buildCompFmOffset(Map<String, List<Float>> featureMaps, int trainDataSize) {
|
||||
List<CompressWeight> compressWeights = EncodeExecutor.getInstance().encode(featureMaps, trainDataSize);
|
||||
if (compressWeights == null || compressWeights.size() == 0) {
|
||||
LOGGER.severe("[Compression] the return compressWeights from <encodeExecutor.encode> is " +
|
||||
"null, please check");
|
||||
retCode = ResponseCode.RequestError;
|
||||
status = FLClientStatus.FAILED;
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
int compFeatureSize = compressWeights.size();
|
||||
int[] compFmOffsets = new int[compFeatureSize];
|
||||
int index = 0;
|
||||
for (CompressWeight compressWeight : compressWeights) {
|
||||
String weightFullname = compressWeight.getWeightFullname();
|
||||
List<Byte> compressData = compressWeight.getCompressData();
|
||||
float minVal = compressWeight.getMinValue();
|
||||
float maxVal = compressWeight.getMaxValue();
|
||||
byte[] data = new byte[compressData.size()];
|
||||
LOGGER.info(Common.addTag("[updateModel build compressWeight] feature name: "
|
||||
+ weightFullname + ", feature size: " + data.length));
|
||||
for (int j = 0; j < data.length; j++) {
|
||||
data[j] = compressData.get(j);
|
||||
}
|
||||
int featureName = builder.createString(weightFullname);
|
||||
int weight = CompressFeatureMap.createCompressDataVector(builder, data);
|
||||
int featureMap = CompressFeatureMap.createCompressFeatureMap(builder, featureName, weight,
|
||||
minVal, maxVal);
|
||||
LOGGER.info(Common.addTag("[Compression]" +
|
||||
" featureName: " + weightFullname +
|
||||
", min_val: " + minVal +
|
||||
", max_val: " + maxVal));
|
||||
compFmOffsets[index] = featureMap;
|
||||
index += 1;
|
||||
}
|
||||
return RequestUpdateModel.createCompressFeatureMapVector(builder, compFmOffsets);
|
||||
}
|
||||
|
||||
private int buildNameVecOffset(ArrayList<String> updateFeatureName) {
|
||||
int featureSize = updateFeatureName.size();
|
||||
int[] nameVecOffsets = new int[featureSize];
|
||||
for (int i = 0; i < featureSize; i++) {
|
||||
String key = updateFeatureName.get(i);
|
||||
int featureName = builder.createString(key);
|
||||
nameVecOffsets[i] = featureName;
|
||||
}
|
||||
return RequestUpdateModel.createNameVecVector(builder, nameVecOffsets);
|
||||
}
|
||||
|
||||
private int buildFmOffset(Map<String, List<Float>> featureMaps, ArrayList<String> updateFeatureName) {
|
||||
int featureSize = updateFeatureName.size();
|
||||
int[] fmOffsets = new int[featureSize];
|
||||
for (int i = 0; i < featureSize; i++) {
|
||||
String key = updateFeatureName.get(i);
|
||||
List<Float> featureMap = featureMaps.get(key);
|
||||
float[] data = new float[featureMap.size()];
|
||||
LOGGER.info(Common.addTag("[updateModel build featuresMap] feature name: " + key + " feature " +
|
||||
"size: " + data.length));
|
||||
for (int j = 0; j < data.length; j++) {
|
||||
data[j] = featureMap.get(j);
|
||||
}
|
||||
int featureName = builder.createString(key);
|
||||
int weight = FeatureMap.createDataVector(builder, data);
|
||||
int featureMapOff = FeatureMap.createFeatureMap(builder, featureName, weight);
|
||||
fmOffsets[i] = featureMapOff;
|
||||
}
|
||||
return RequestUpdateModel.createFeatureMapVector(builder, fmOffsets);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -417,6 +497,10 @@ public class UpdateModel {
|
|||
RequestUpdateModel.addFlId(this.builder, idOffset);
|
||||
RequestUpdateModel.addTimestamp(builder, this.timestampOffset);
|
||||
RequestUpdateModel.addIteration(builder, this.iteration);
|
||||
RequestUpdateModel.addCompressFeatureMap(builder, this.compFmOffset);
|
||||
RequestUpdateModel.addUploadCompressType(builder, this.uploadCompressType);
|
||||
RequestUpdateModel.addUploadSparseRate(builder, this.uploadSparseRate);
|
||||
RequestUpdateModel.addNameVec(builder, this.nameVecOffset);
|
||||
RequestUpdateModel.addFeatureMap(builder, this.fmOffset);
|
||||
RequestUpdateModel.addSignature(builder, this.signDataOffset);
|
||||
RequestUpdateModel.addUploadLoss(builder, this.uploadLossOffset);
|
||||
|
|
|
@ -0,0 +1,40 @@
|
|||
/*
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2022. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.mindspore.flclient.compression;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
import static mindspore.schema.CompressType.NO_COMPRESS;
|
||||
import static mindspore.schema.CompressType.QUANT;
|
||||
|
||||
/**
|
||||
* The compress mod.
|
||||
*
|
||||
* @since 2021-12-21
|
||||
*/
|
||||
|
||||
public class CompressMode {
|
||||
// compress type -> num bits
|
||||
public static final Map<Byte, Integer> COMPRESS_TYPE_MAP = new HashMap<>();
|
||||
|
||||
static {
|
||||
COMPRESS_TYPE_MAP.put(NO_COMPRESS, -1);
|
||||
COMPRESS_TYPE_MAP.put(QUANT, 8);
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,83 @@
|
|||
/*
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2022. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.mindspore.flclient.compression;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Compress Weight Bean
|
||||
*
|
||||
* @since 2021-12-21
|
||||
*/
|
||||
public class CompressWeight {
|
||||
private String weightFullname;
|
||||
private List<Byte> compressData;
|
||||
private float minValue;
|
||||
private float maxValue;
|
||||
|
||||
public CompressWeight() {
|
||||
}
|
||||
|
||||
public CompressWeight(String weightFullname, List<Byte> compressData, float minValue, float maxValue) {
|
||||
this.weightFullname = weightFullname;
|
||||
this.compressData = compressData;
|
||||
this.minValue = minValue;
|
||||
this.maxValue = maxValue;
|
||||
}
|
||||
|
||||
public String getWeightFullname() {
|
||||
return weightFullname;
|
||||
}
|
||||
|
||||
public void setWeightFullname(String weightFullname) {
|
||||
this.weightFullname = weightFullname;
|
||||
}
|
||||
|
||||
public List<Byte> getCompressData() {
|
||||
return compressData;
|
||||
}
|
||||
|
||||
public void setCompressData(List<Byte> compressData) {
|
||||
this.compressData = compressData;
|
||||
}
|
||||
|
||||
public float getMinValue() {
|
||||
return minValue;
|
||||
}
|
||||
|
||||
public void setMinValue(float minValue) {
|
||||
this.minValue = minValue;
|
||||
}
|
||||
|
||||
public float getMaxValue() {
|
||||
return maxValue;
|
||||
}
|
||||
|
||||
public void setMaxValue(float maxValue) {
|
||||
this.maxValue = maxValue;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "CompressWeight{" +
|
||||
"weightFullname='" + weightFullname + '\'' +
|
||||
", compressData=" + compressData +
|
||||
", minValue=" + minValue +
|
||||
", maxValue=" + maxValue +
|
||||
'}';
|
||||
}
|
||||
}
|
|
@ -0,0 +1,115 @@
|
|||
/*
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2022. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.mindspore.flclient.compression;
|
||||
|
||||
import com.google.flatbuffers.FlatBufferBuilder;
|
||||
|
||||
import com.mindspore.flclient.Common;
|
||||
import com.mindspore.flclient.StartFLJob;
|
||||
import mindspore.schema.CompressFeatureMap;
|
||||
import mindspore.schema.FeatureMap;
|
||||
import mindspore.schema.CompressType;
|
||||
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
import static mindspore.schema.CompressType.QUANT;
|
||||
/**
|
||||
* Compress Executor
|
||||
*
|
||||
* @since 2021-12-21
|
||||
*/
|
||||
public class DecodeExecutor {
|
||||
private static final Logger LOGGER = Logger.getLogger(DecodeExecutor.class.toString());
|
||||
|
||||
private static volatile DecodeExecutor compressExecutor;
|
||||
|
||||
private DecodeExecutor() {}
|
||||
|
||||
public static DecodeExecutor getInstance() {
|
||||
if (compressExecutor == null) {
|
||||
synchronized (DecodeExecutor.class) {
|
||||
if (compressExecutor == null) {
|
||||
compressExecutor = new DecodeExecutor();
|
||||
}
|
||||
}
|
||||
}
|
||||
return compressExecutor;
|
||||
}
|
||||
|
||||
public List<FeatureMap> deCompressWeight(byte compressType, List<CompressFeatureMap> compressFeatureMapList) {
|
||||
if (!CompressMode.COMPRESS_TYPE_MAP.containsKey(compressType)) {
|
||||
return new ArrayList<>();
|
||||
}
|
||||
LOGGER.info(Common.addTag("[deCompressWeight] create " + CompressType.name(compressType) + " feature map."));
|
||||
int num_bits = CompressMode.COMPRESS_TYPE_MAP.get(compressType);
|
||||
if (compressType == QUANT) {
|
||||
return deCompressQuantMinMax(compressFeatureMapList, num_bits);
|
||||
}
|
||||
return new ArrayList<>();
|
||||
}
|
||||
|
||||
private List<FeatureMap> deCompressQuantMinMax(List<CompressFeatureMap> compressFeatureMapList, int num_bits) {
|
||||
float temp1 = (float) (Math.pow(2, num_bits) - 1);
|
||||
float temp2 = (float) Math.pow(2, num_bits - 1);
|
||||
|
||||
Map<String, float[]> deCompressFeatureMaps = new HashMap<>();
|
||||
int compressFeatureMapLength = compressFeatureMapList.size();
|
||||
for (int i = 0; i < compressFeatureMapLength; i++) {
|
||||
CompressFeatureMap compressFeatureMap = compressFeatureMapList.get(i);
|
||||
String weightName = compressFeatureMap.weightFullname();
|
||||
int compressDataLength = compressFeatureMap.compressDataLength();
|
||||
List<Byte> compressWeightList = new ArrayList<>();
|
||||
for (int j = 0; j < compressDataLength; j++) {
|
||||
compressWeightList.add(compressFeatureMap.compressData(j));
|
||||
}
|
||||
float minVal = compressFeatureMap.minVal();
|
||||
float maxVal = compressFeatureMap.maxVal();
|
||||
float scale_value = (float) ((maxVal - minVal) / temp1 + 1e-10);
|
||||
float[] params = new float[compressWeightList.size()];
|
||||
for (int j = 0; j < params.length; j++) {
|
||||
float val = (compressWeightList.get(j).intValue() + temp2) * scale_value + minVal;
|
||||
params[j] = val;
|
||||
}
|
||||
deCompressFeatureMaps.put(weightName, params);
|
||||
}
|
||||
|
||||
List<FeatureMap> featureMaps = new ArrayList<>();
|
||||
for (String weightName : deCompressFeatureMaps.keySet()) {
|
||||
FlatBufferBuilder builder = new FlatBufferBuilder(0);
|
||||
int weightFullnameOffset = builder.createString(weightName);
|
||||
float[] data = deCompressFeatureMaps.get(weightName);
|
||||
int dataOffset = FeatureMap.createDataVector(builder, data);
|
||||
|
||||
FeatureMap.startFeatureMap(builder);
|
||||
FeatureMap.addWeightFullname(builder, weightFullnameOffset);
|
||||
FeatureMap.addData(builder, dataOffset);
|
||||
|
||||
int orc = FeatureMap.endFeatureMap(builder);
|
||||
builder.finish(orc);
|
||||
ByteBuffer buf = builder.dataBuffer();
|
||||
FeatureMap featureMap = FeatureMap.getRootAsFeatureMap(buf);
|
||||
|
||||
featureMaps.add(featureMap);
|
||||
}
|
||||
return featureMaps;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,167 @@
|
|||
/*
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2022. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.mindspore.flclient.compression;
|
||||
|
||||
import com.mindspore.flclient.LocalFLParameter;
|
||||
import static mindspore.schema.CompressType.DIFF_SPARSE_QUANT;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Queue;
|
||||
import java.util.PriorityQueue;
|
||||
|
||||
/**
|
||||
* Encode Executor
|
||||
*
|
||||
* @since 2021-12-21
|
||||
*/
|
||||
public class EncodeExecutor {
|
||||
private final LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
|
||||
|
||||
private static volatile EncodeExecutor encodeExecutor;
|
||||
|
||||
private EncodeExecutor() {}
|
||||
|
||||
public static EncodeExecutor getInstance() {
|
||||
if (encodeExecutor == null) {
|
||||
synchronized (EncodeExecutor.class) {
|
||||
if (encodeExecutor == null) {
|
||||
encodeExecutor = new EncodeExecutor();
|
||||
}
|
||||
}
|
||||
}
|
||||
return encodeExecutor;
|
||||
}
|
||||
|
||||
private static final int multiplier = 2147483647;
|
||||
private static final double increment = 4294967294.0;
|
||||
private static final int modulo = 48271;
|
||||
|
||||
private List<Integer> constructMaskArray(int paramNum) {
|
||||
int seed = localFLParameter.getSeed();
|
||||
float uploadSparseRatio = localFLParameter.getUploadSparseRatio();
|
||||
|
||||
List<Integer> maskArray = new ArrayList<>();
|
||||
|
||||
int retain_num = (int) ((float) (paramNum) * uploadSparseRatio);
|
||||
for (int i = 0; i < retain_num; ++i) {
|
||||
maskArray.add(1);
|
||||
}
|
||||
for (int i = retain_num; i < paramNum; ++i) {
|
||||
maskArray.add(0);
|
||||
}
|
||||
|
||||
seed = ((seed + multiplier) * modulo) % multiplier;
|
||||
for (int i = 0; i < paramNum; ++i) {
|
||||
// generate random number in (0, 1)
|
||||
double rand = (double)(seed) / increment + 0.5;
|
||||
// update seed
|
||||
seed = (seed * modulo) % multiplier;
|
||||
|
||||
int j = (int)(rand * (double)(paramNum - i)) + i;
|
||||
int temp = maskArray.get(i);
|
||||
maskArray.set(i, maskArray.get(j));
|
||||
maskArray.set(j, temp);
|
||||
}
|
||||
return maskArray;
|
||||
}
|
||||
|
||||
public List<CompressWeight> enDiffSparseQuant(Map<String, List<Float>> featureMaps, int numBits,
|
||||
int trainDataSize) {
|
||||
List<CompressWeight> compressWeights = new ArrayList<>();
|
||||
|
||||
// difference encode
|
||||
Map<String, float[]> oldFeatureMap = localFLParameter.getOldFeatureMap();
|
||||
Map<String, List<Float>> diffFeatureMaps = new HashMap<>();
|
||||
for (String featureMapName : featureMaps.keySet()) {
|
||||
List<Float> diffs = new ArrayList<>();
|
||||
List<Float> featureMap = featureMaps.get(featureMapName);
|
||||
float[] dataBeforeTrain = oldFeatureMap.get(featureMapName);
|
||||
int length = dataBeforeTrain.length;
|
||||
for (int i = 0; i < length; ++i) {
|
||||
float diff = featureMap.get(i) - dataBeforeTrain[i] * (float) trainDataSize;
|
||||
diffs.add(diff);
|
||||
}
|
||||
diffFeatureMaps.put(featureMapName, diffs);
|
||||
}
|
||||
|
||||
// sparse encode
|
||||
int paramNum = 0;
|
||||
for (String featureMapName : diffFeatureMaps.keySet()) {
|
||||
int weightSize = diffFeatureMaps.get(featureMapName).size();
|
||||
paramNum += weightSize;
|
||||
}
|
||||
List<Integer> maskArray = constructMaskArray(paramNum);
|
||||
|
||||
Map<String, List<Float>> sparseFeatureMaps = new HashMap<>();
|
||||
int index = 0;
|
||||
for (String featureMapName : diffFeatureMaps.keySet()) {
|
||||
List<Float> sparseFeatureMap = new ArrayList<>();
|
||||
List<Float> Weight = diffFeatureMaps.get(featureMapName);
|
||||
for (Float dataValue : Weight) {
|
||||
if (maskArray.get(index) == 1) {
|
||||
sparseFeatureMap.add(dataValue);
|
||||
}
|
||||
index += 1;
|
||||
}
|
||||
sparseFeatureMaps.put(featureMapName, sparseFeatureMap);
|
||||
}
|
||||
|
||||
// quant encode
|
||||
float temp1 = (float) (1 << numBits) - 1.0f;
|
||||
float temp2 = (float) (1 << (numBits - 1));
|
||||
for (String featureMapName : sparseFeatureMaps.keySet()) {
|
||||
CompressWeight compressWeight = new CompressWeight();
|
||||
compressWeight.setWeightFullname(featureMapName);
|
||||
|
||||
List<Float> sparseFeatureMap = sparseFeatureMaps.get(featureMapName);
|
||||
|
||||
// get min and max value
|
||||
Float minVal = Float.MAX_VALUE;
|
||||
float maxVal = -minVal;
|
||||
for (Float value : sparseFeatureMap) {
|
||||
if (value < minVal) {
|
||||
minVal = value;
|
||||
}
|
||||
if (value > maxVal) {
|
||||
maxVal = value;
|
||||
}
|
||||
}
|
||||
compressWeight.setMinValue(minVal);
|
||||
compressWeight.setMaxValue(maxVal);
|
||||
float scale_value = (maxVal - minVal) / temp1 + 1e-10f;
|
||||
List<Byte> compressData = new ArrayList<>();
|
||||
for (Float aFloat : sparseFeatureMap) {
|
||||
compressData.add((byte) (Math.round((aFloat - minVal) / scale_value - temp2)));
|
||||
}
|
||||
compressWeight.setCompressData(compressData);
|
||||
compressWeights.add(compressWeight);
|
||||
}
|
||||
|
||||
return compressWeights;
|
||||
}
|
||||
|
||||
public List<CompressWeight> encode(Map<String, List<Float>> featureMaps, int trainDataSize) {
|
||||
byte uploadCompressType = localFLParameter.getUploadCompressType();
|
||||
if (uploadCompressType == DIFF_SPARSE_QUANT) {
|
||||
return enDiffSparseQuant(featureMaps, 8, trainDataSize);
|
||||
}
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
}
|
|
@ -15,8 +15,9 @@
|
|||
"""Context for parameter server training mode"""
|
||||
|
||||
import os
|
||||
from mindspore._checkparam import Validator
|
||||
from mindspore._checkparam import Validator, Rel
|
||||
from mindspore._c_expression import PSContext
|
||||
from mindspore import log as logger
|
||||
|
||||
_ps_context = None
|
||||
|
||||
|
@ -79,6 +80,9 @@ _set_ps_context_func_map = {
|
|||
"sign_global_lr": ps_context().set_sign_global_lr,
|
||||
"sign_dim_out": ps_context().set_sign_dim_out,
|
||||
"checkpoint_dir": ps_context().set_checkpoint_dir,
|
||||
"upload_compress_type": ps_context().set_upload_compress_type,
|
||||
"upload_sparse_rate": ps_context().set_upload_sparse_rate,
|
||||
"download_compress_type": ps_context().set_download_compress_type,
|
||||
}
|
||||
|
||||
_get_ps_context_func_map = {
|
||||
|
@ -126,7 +130,10 @@ _get_ps_context_func_map = {
|
|||
"sign_thr_ratio": ps_context().sign_thr_ratio,
|
||||
"sign_global_lr": ps_context().sign_global_lr,
|
||||
"sign_dim_out": ps_context().sign_dim_out,
|
||||
"checkpoint_dir": ps_context().checkpoint_dir
|
||||
"checkpoint_dir": ps_context().checkpoint_dir,
|
||||
"upload_compress_type": ps_context().upload_compress_type,
|
||||
"upload_sparse_rate": ps_context().upload_sparse_rate,
|
||||
"download_compress_type": ps_context().download_compress_type,
|
||||
}
|
||||
|
||||
_check_positive_int_keys = ["server_num", "scheduler_port", "fl_server_port",
|
||||
|
@ -140,6 +147,15 @@ _check_positive_float_keys = ["update_model_ratio", "client_learning_rate"]
|
|||
|
||||
_check_port_keys = ["scheduler_port", "fl_server_port"]
|
||||
|
||||
_check_string_keys = {
|
||||
"upload_compress_type": ["NO_COMPRESS", "DIFF_SPARSE_QUANT"],
|
||||
"download_compress_type": ["NO_COMPRESS", "QUANT"],
|
||||
}
|
||||
|
||||
_check_float_range_keys = {
|
||||
"upload_sparse_rate": {"lower_limit": 0.0, "upper_limit": 1.0, "rel": Rel.INC_RIGHT},
|
||||
}
|
||||
|
||||
def _get_ps_mode_rank():
|
||||
ps_rank = ps_context().ps_rank_id()
|
||||
if ps_rank == -1:
|
||||
|
@ -183,6 +199,7 @@ def _set_ps_context(**kwargs):
|
|||
Examples:
|
||||
>>> context.set_ps_context(enable_ps=True, enable_ssl=True, client_password='123456', server_password='123456')
|
||||
"""
|
||||
kwargs = _check_conflict_value(kwargs)
|
||||
for key, value in kwargs.items():
|
||||
if key not in _set_ps_context_func_map:
|
||||
raise ValueError("Set PS context keyword %s is not recognized!" % key)
|
||||
|
@ -287,6 +304,31 @@ def _check_value(key, value):
|
|||
if key in _check_positive_float_keys:
|
||||
Validator.check_positive_float(value, key)
|
||||
|
||||
if key in _check_string_keys:
|
||||
try:
|
||||
string_keys = _check_string_keys[key]
|
||||
Validator.check_string(value, string_keys)
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
if key in _check_float_range_keys:
|
||||
try:
|
||||
range_keys = _check_float_range_keys[key]
|
||||
Validator.check_float_range(value, **range_keys)
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
if key in _check_port_keys:
|
||||
if value < 1 or value > 65535:
|
||||
raise ValueError("The range of %s must be 1 to 65535, but got %d." % (key, value))
|
||||
|
||||
|
||||
def _check_conflict_value(kwargs):
|
||||
if "upload_compress_type" in kwargs and " encrypt_type" in kwargs:
|
||||
if kwargs["upload_compress_type"] != "NO_COMPRESS" and kwargs["encrypt_type"] in ("SIGNDS", "PW_ENCRYPT"):
|
||||
logger.warning("The '{}' and '{}' are conflicted, and in '{}' mode the"
|
||||
" 'upload_compress_type' will be 'NO_COMPRESS'".format(kwargs["encrypt_type"],
|
||||
kwargs["upload_compress_type"],
|
||||
kwargs["encrypt_type"]))
|
||||
kwargs["upload_compress_type"] = "NO_COMPRESS"
|
||||
return kwargs
|
||||
|
|
|
@ -47,6 +47,16 @@ table FeatureMap{
|
|||
weight_fullname:string;
|
||||
data:[float];
|
||||
}
|
||||
|
||||
enum CompressType:byte {NO_COMPRESS = 0, DIFF_SPARSE_QUANT = 1, QUANT = 2}
|
||||
|
||||
table CompressFeatureMap{
|
||||
weight_fullname:string;
|
||||
compress_data:[int8];
|
||||
min_val:float;
|
||||
max_val:float;
|
||||
}
|
||||
|
||||
table RequestFLJob{
|
||||
fl_name:string;
|
||||
fl_id:string;
|
||||
|
@ -58,6 +68,7 @@ table RequestFLJob{
|
|||
equip_cert:string;
|
||||
equip_ca_cert:string;
|
||||
root_cert:string;
|
||||
download_compress_types:[CompressType];
|
||||
}
|
||||
table ResponseFLJob {
|
||||
retcode:int;
|
||||
|
@ -68,6 +79,10 @@ table ResponseFLJob {
|
|||
fl_plan_config:FLPlan;
|
||||
feature_map:[FeatureMap];
|
||||
timestamp:string;
|
||||
upload_compress_type:CompressType;
|
||||
upload_sparse_rate:float;
|
||||
download_compress_type:CompressType;
|
||||
compress_feature_map:[CompressFeatureMap];
|
||||
}
|
||||
|
||||
table FLPlan {
|
||||
|
@ -94,6 +109,10 @@ table RequestUpdateModel{
|
|||
upload_loss:float;
|
||||
sign:int;
|
||||
index_array:[int];
|
||||
compress_feature_map:[CompressFeatureMap];
|
||||
upload_compress_type:CompressType;
|
||||
upload_sparse_rate:float;
|
||||
name_vec:[string];
|
||||
}
|
||||
|
||||
table ResponseUpdateModel{
|
||||
|
@ -132,6 +151,7 @@ table RequestGetModel{
|
|||
fl_name:string;
|
||||
iteration:int;
|
||||
timestamp:string;
|
||||
download_compress_types:[CompressType];
|
||||
}
|
||||
table ResponseGetModel{
|
||||
retcode:int;
|
||||
|
@ -139,6 +159,8 @@ table ResponseGetModel{
|
|||
iteration:int;
|
||||
feature_map:[FeatureMap];
|
||||
timestamp:string;
|
||||
download_compress_type:CompressType;
|
||||
compress_feature_map:[CompressFeatureMap];
|
||||
}
|
||||
|
||||
table RequestAsyncGetModel{
|
||||
|
|
Loading…
Reference in New Issue