forked from mindspore-Ecosystem/mindspore

fl compression

parent 0373d2f915
commit ce17db99a6
@@ -51,6 +51,8 @@ if(NOT ENABLE_CPU OR WIN32)
     list(REMOVE_ITEM _FL_SRC_FILES "armour/cipher/cipher_reconstruct.cc")
     list(REMOVE_ITEM _FL_SRC_FILES "armour/cipher/cipher_shares.cc")
     list(REMOVE_ITEM _FL_SRC_FILES "armour/cipher/cipher_unmask.cc")
+    list(REMOVE_ITEM _FL_SRC_FILES "compression/decode_executor.cc")
+    list(REMOVE_ITEM _FL_SRC_FILES "compression/encode_executor.cc")
 endif()

 if(CMAKE_SYSTEM_NAME MATCHES "Darwin")
fl/compression/decode_executor.cc (new file)
@@ -0,0 +1,150 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "fl/compression/decode_executor.h"

namespace mindspore {
namespace fl {
namespace compression {
std::vector<int> DecodeExecutor::ConstructMaskArray(int seed, float upload_sparse_rate, size_t param_num) {
  static int multiplier = 2147483647;
  static double increment = 4294967294.0;
  static int modulo = 48271;
  size_t retain_num = size_t(static_cast<float>(param_num) * upload_sparse_rate);
  if (retain_num == 0) {
    MS_LOG(WARNING) << "The retain_num is 0, and upload_sparse_rate is too small.";
  }
  std::vector<int> mask_array(param_num, 0);
  for (size_t i = 0; i < retain_num; ++i) {
    mask_array[i] = 1;
  }

  seed = ((seed + multiplier) * modulo) % multiplier;
  for (size_t i = 0; i < param_num; ++i) {
    // generate random number in (0, 1)
    double rand = static_cast<double>(seed) / increment + 0.5;
    // update seed
    seed = (seed * modulo) % multiplier;
    size_t j = size_t(rand * static_cast<double>(param_num - i)) + i;
    int temp = mask_array[i];
    mask_array[i] = mask_array[j];
    mask_array[j] = temp;
  }
  return mask_array;
}

bool DecodeExecutor::DeQuantSparseDiff(std::map<std::string, std::vector<float>> *weight_map,
                                       const std::vector<CompressFeatureMap> &compress_feature_maps, size_t num_bits,
                                       float upload_sparse_rate, int seed, const std::vector<std::string> &name_vec,
                                       size_t data_size) {
  std::vector<std::vector<float>> decompress_feature_maps;

  // origin parameters
  std::vector<size_t> shape_vec;
  size_t param_num = 0;
  const auto &iter_to_model = mindspore::fl::server::ModelStore::GetInstance().iteration_to_model();
  size_t latest_iter_num = iter_to_model.rbegin()->first;
  std::map<std::string, AddressPtr> feature_maps =
    mindspore::fl::server::ModelStore::GetInstance().GetModelByIterNum(latest_iter_num);
  // get shape vector and number of upload parameters
  for (const auto &name : name_vec) {
    size_t shape = feature_maps[name]->size / sizeof(float);
    shape_vec.emplace_back(shape);
    param_num += shape;
  }
  MS_LOG(DEBUG) << "Compression get last weights success!";

  // quant decode
  auto temp1 = static_cast<float>(1 << num_bits) - 1.0f;
  auto temp2 = static_cast<float>(1 << (num_bits - 1));
  std::vector<float> de_min_max_feature_map;
  for (auto compress_feature_map : compress_feature_maps) {
    float min_val = compress_feature_map.min_val;
    float max_val = compress_feature_map.max_val;
    float scale_val = static_cast<float>(max_val - min_val) / temp1 + 1e-10f;
    size_t size = compress_feature_map.compress_data.size();
    for (size_t i = 0; i < size; ++i) {
      de_min_max_feature_map.emplace_back(
        (static_cast<float>(compress_feature_map.compress_data[i]) + temp2) * scale_val + min_val);
    }
  }
  MS_LOG(DEBUG) << "Compression quant decode success!";

  // sparse decode
  std::vector<int> mask_array = ConstructMaskArray(seed, upload_sparse_rate, param_num);
  size_t index = 0;
  size_t de_min_max_feature_map_index = 0;
  for (const auto &shape : shape_vec) {
    std::vector<float> feature_map(shape);
    for (size_t i = 0; i < shape; ++i) {
      if (index >= mask_array.size()) {
        MS_LOG(WARNING) << "The mask_array and parameter shape is not matched.";
        return false;
      }
      if (mask_array[index] == 1) {
        if (de_min_max_feature_map_index >= de_min_max_feature_map.size()) {
          MS_LOG(WARNING) << "The number of upload parameters is too small.";
          return false;
        }
        feature_map[i] = de_min_max_feature_map[de_min_max_feature_map_index];
        de_min_max_feature_map_index += 1;
      } else {
        feature_map[i] = 0.0f;
      }
      index += 1;
    }
    decompress_feature_maps.emplace_back(feature_map);
  }
  MS_LOG(DEBUG) << "Compression sparse decode success!";

  // difference decode
  for (size_t i = 0; i < decompress_feature_maps.size(); ++i) {
    size_t feature_size = decompress_feature_maps[i].size();
    std::string name = name_vec[i];
    float *weight_data = reinterpret_cast<float *>(feature_maps[name]->addr);
    auto &weight_item = (*weight_map)[name];
    weight_item.resize(feature_size);
    for (size_t j = 0; j < feature_size; ++j) {
      weight_item[j] = decompress_feature_maps[i][j] + data_size * weight_data[j];
    }
  }
  MS_LOG(DEBUG) << "Compression difference decode success!";

  return true;
}

bool DecodeExecutor::Decode(std::map<std::string, std::vector<float>> *weight_map,
                            const std::vector<CompressFeatureMap> &compress_feature_maps,
                            schema::CompressType upload_compress_type, float upload_sparse_rate, int seed,
                            const std::vector<std::string> &name_vec, size_t data_size) {
  if (upload_compress_type == schema::CompressType_DIFF_SPARSE_QUANT) {
    return DeQuantSparseDiff(weight_map, compress_feature_maps, 8, upload_sparse_rate, seed, name_vec, data_size);
  }
  return false;
}

schema::CompressType DecodeExecutor::GetCompressType(schema::CompressType upload_compress_type) {
  if (upload_compress_type == schema::CompressType_DIFF_SPARSE_QUANT) {
    MS_LOG(DEBUG) << "This upload compress type is DIFF_SPARSE_QUANT.";
    return schema::CompressType_DIFF_SPARSE_QUANT;
  }

  MS_LOG(DEBUG) << "This upload compress type is NO_COMPRESS.";
  return schema::CompressType_NO_COMPRESS;
}
}  // namespace compression
}  // namespace fl
}  // namespace mindspore
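One note on ConstructMaskArray before moving on: the constants are the classic Lehmer generator with the names swapped (2147483647 = 2^31 - 1 is the modulus and 48271 the multiplier), and rand actually falls in [0.5, 1) rather than the commented (0, 1). Both quirks are benign only because client and server run the identical routine; determinism is the real contract. Below is a minimal standalone sketch of that contract, with our own names, and with the seed arithmetic widened to 64 bits to avoid the signed-overflow wrap the committed int version relies on:

#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>

// Sketch of the deterministic sparse mask (names are ours). The first
// retain_num slots are set to 1, then a seeded shuffle spreads them out.
// Client and server run the same routine with the same (seed, rate,
// param_num), so both sides agree on which indices carry uploaded values.
std::vector<int> BuildMask(int64_t seed, float rate, size_t param_num) {
  const int64_t kModulus = 2147483647;  // 2^31 - 1 ("multiplier" in the commit)
  const int64_t kMultiplier = 48271;    // Lehmer multiplier ("modulo" in the commit)
  const double kIncrement = 4294967294.0;
  size_t retain_num = static_cast<size_t>(static_cast<float>(param_num) * rate);
  std::vector<int> mask(param_num, 0);
  for (size_t i = 0; i < retain_num; ++i) mask[i] = 1;
  seed = ((seed + kModulus) * kMultiplier) % kModulus;
  for (size_t i = 0; i < param_num; ++i) {
    double rand = static_cast<double>(seed) / kIncrement + 0.5;  // lands in [0.5, 1)
    seed = (seed * kMultiplier) % kModulus;
    size_t j = static_cast<size_t>(rand * static_cast<double>(param_num - i)) + i;
    std::swap(mask[i], mask[j]);
  }
  return mask;
}

int main() {
  auto a = BuildMask(7, 0.25f, 16);
  auto b = BuildMask(7, 0.25f, 16);
  std::printf("same seed gives same mask: %s\n", a == b ? "yes" : "no");
  return 0;
}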
fl/compression/decode_executor.h (new file)
@@ -0,0 +1,74 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_FL_COMPRESSION_DECODE_EXECUTOR_H_
#define MINDSPORE_CCSRC_FL_COMPRESSION_DECODE_EXECUTOR_H_

#include <memory>
#include <string>
#include <vector>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <functional>
#include <algorithm>
#include <regex>
#include <map>
#include <utility>
#include "proto/comm.pb.h"
#include "schema/fl_job_generated.h"
#include "schema/cipher_generated.h"
#include "fl/server/model_store.h"
#include "fl/server/common.h"
#include "ps/ps_context.h"

namespace mindspore {
namespace fl {
namespace compression {
struct CompressFeatureMap {
  std::string weight_fullname;
  std::vector<int8_t> compress_data;
  float min_val;
  float max_val;
};

class DecodeExecutor {
 public:
  static DecodeExecutor &GetInstance() {
    static DecodeExecutor instance;
    return instance;
  }

  // construct mask array for random sparse
  std::vector<int> ConstructMaskArray(int seed, float upload_sparse_rate, size_t param_num);

  // decode min_max quantization and random sparse and parameter difference
  bool DeQuantSparseDiff(std::map<std::string, std::vector<float>> *weight_map,
                         const std::vector<CompressFeatureMap> &compress_feature_maps, size_t num_bits,
                         float upload_sparse_rate, int seed, const std::vector<std::string> &name_vec,
                         size_t data_size);

  // decode
  bool Decode(std::map<std::string, std::vector<float>> *weight_map,
              const std::vector<CompressFeatureMap> &compress_feature_maps, schema::CompressType upload_compress_type,
              float upload_sparse_rate, int seed, const std::vector<std::string> &name_vec, size_t data_size);

  schema::CompressType GetCompressType(schema::CompressType upload_compress_type);
};
}  // namespace compression
}  // namespace fl
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_FL_COMPRESSION_DECODE_EXECUTOR_H_
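For orientation, a hypothetical call into this interface, mirroring what UpdateModelKernel::DecodeFeatureMap does later in this commit; the literal values are invented, and the uploads vector would really be parsed out of a RequestUpdateModel flatbuffer:

// Hypothetical call-site fragment (not a full program); values are illustrative.
std::map<std::string, std::vector<float>> weight_map;
std::vector<mindspore::fl::compression::CompressFeatureMap> uploads;  // parsed from the request
std::vector<std::string> name_vec = {"conv1.weight", "fc.bias"};
int seed = 5;  // the iteration number doubles as the mask seed
bool ok = mindspore::fl::compression::DecodeExecutor::GetInstance().Decode(
  &weight_map, uploads, schema::CompressType_DIFF_SPARSE_QUANT,
  /*upload_sparse_rate=*/0.25f, seed, name_vec, /*data_size=*/32);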
fl/compression/encode_executor.cc (new file)
@@ -0,0 +1,102 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "fl/compression/encode_executor.h"

#include <arpa/inet.h>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <functional>
#include <algorithm>
#include <regex>
#include <map>
#include <utility>
#include <vector>
#include "fl/server/common.h"

namespace mindspore {
namespace fl {
namespace compression {
bool CompressExecutor::EnableCompressWeight(const schema::CompressType compressType) {
  return kCompressTypeMap.count(compressType) > 0;
}

bool CompressExecutor::construct_compress_weight(std::map<std::string, CompressWeight> *compressWeights,
                                                 std::map<std::string, std::vector<float>> feature_maps,
                                                 const schema::CompressType compressType) {
  if (compressType == schema::CompressType_QUANT) {
    return quant_min_max(compressWeights, feature_maps, kCompressTypeMap.at(compressType));
  }
  return false;
}

bool CompressExecutor::quant_min_max(std::map<std::string, CompressWeight> *compressWeights,
                                     std::map<std::string, std::vector<float>> feature_maps, size_t num_bits) {
  auto temp1 = static_cast<float>(1 << num_bits) - 1.0f;
  auto temp2 = static_cast<float>(1 << (num_bits - 1));
  for (const auto &feature_map : feature_maps) {
    std::string weight_name = feature_map.first;
    float min_value = 1e10f;
    float max_value = -min_value;
    for (const auto &feature : feature_map.second) {
      if (feature > max_value) {
        max_value = feature;
      }
      if (feature < min_value) {
        min_value = feature;
      }
    }
    float scale_value = (max_value - min_value) / temp1 + 1e-10f;
    size_t size = feature_map.second.size();
    if (size == 0) {
      MS_LOG(WARNING) << "The size of parameters is zero.";
      return false;
    }
    CompressWeight compressWeight;
    for (size_t i = 0; i < size; ++i) {
      auto round_data = round((feature_map.second[i] - min_value) / scale_value - temp2);
      // bit pack can be implemented here in the future
      auto int8_data = int8_t(round_data);
      compressWeight.compress_data.emplace_back(int8_data);
    }
    compressWeight.min_val = min_value;
    compressWeight.max_val = max_value;
    compressWeight.compress_data_len = size;

    (*compressWeights)[weight_name] = compressWeight;
  }
  return true;
}

schema::CompressType CompressExecutor::GetCompressType(const flatbuffers::Vector<int8_t> *download_compress_types) {
  schema::CompressType compressType = schema::CompressType_NO_COMPRESS;
  if (download_compress_types == nullptr) {
    MS_LOG(DEBUG) << "The client does not support current download compress type.";
  } else {
    for (size_t i = 0; i < download_compress_types->size(); ++i) {
      auto download_compress_type = download_compress_types->Get(i);
      if (download_compress_type == schema::CompressType_QUANT) {
        compressType = schema::CompressType_QUANT;
        break;
      }
    }
  }
  return compressType;
}
}  // namespace compression
}  // namespace fl
}  // namespace mindspore
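quant_min_max above and the dequantization loop in decode_executor.cc are mirror images. Here is a self-contained sketch of the round trip for one 8-bit value, reusing the same scale and offset expressions (names are ours):

#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  const int num_bits = 8;
  const float levels = static_cast<float>(1 << num_bits) - 1.0f;  // temp1 = 255
  const float offset = static_cast<float>(1 << (num_bits - 1));   // temp2 = 128
  const float min_val = -0.5f, max_val = 0.5f;
  const float scale = (max_val - min_val) / levels + 1e-10f;      // 1e-10f guards max == min

  const float x = 0.123f;  // original weight value
  const auto q = static_cast<int8_t>(std::round((x - min_val) / scale - offset));  // encode
  const float x_hat = (static_cast<float>(q) + offset) * scale + min_val;          // decode
  std::printf("x=%.6f q=%d x_hat=%.6f |err|=%.6f\n", x, q, x_hat, std::fabs(x - x_hat));
  return 0;
}

Reconstruction error is bounded by scale / 2, i.e. (max_val - min_val) / 510 at 8 bits, which is the accuracy cost of the QUANT download path.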
fl/compression/encode_executor.h (new file)
@@ -0,0 +1,68 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_FL_COMPRESSION_ENCODE_EXECUTOR_H_
#define MINDSPORE_CCSRC_FL_COMPRESSION_ENCODE_EXECUTOR_H_

#include <memory>
#include <string>
#include <vector>
#include <map>
#include "proto/comm.pb.h"
#include "schema/fl_job_generated.h"
#include "schema/cipher_generated.h"
#include "fl/armour/secure_protocol/key_agreement.h"
#include "ps/ps_context.h"
#include "ps/core/worker_node.h"
#include "ps/core/cluster_metadata.h"
#include "ps/core/communicator/tcp_communicator.h"
#include "fl/server/common.h"

namespace mindspore {
namespace fl {
namespace compression {
// compress type map: schema::CompressType -> num bits
const std::map<schema::CompressType, size_t> kCompressTypeMap = {{schema::CompressType_QUANT, 8}};

struct CompressWeight {
  std::vector<int8_t> compress_data;
  size_t compress_data_len;
  float min_val;
  float max_val;
};

class CompressExecutor {
 public:
  static CompressExecutor &GetInstance() {
    static CompressExecutor instance;
    return instance;
  }

  bool EnableCompressWeight(const schema::CompressType compressType);

  bool construct_compress_weight(std::map<std::string, CompressWeight> *compressWeights,
                                 std::map<std::string, std::vector<float>> feature_maps,
                                 const schema::CompressType compressType);

  bool quant_min_max(std::map<std::string, CompressWeight> *compressWeights,
                     std::map<std::string, std::vector<float>> feature_maps, size_t num_bits);

  schema::CompressType GetCompressType(const flatbuffers::Vector<int8_t> *download_compress_types);
};
}  // namespace compression
}  // namespace fl
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_FL_COMPRESSION_ENCODE_EXECUTOR_H_
fl/server/common.h
@@ -149,6 +149,11 @@ constexpr auto kUpdateModelRejectClientNum = "updateModelRejectClientNum";
 constexpr auto kGetModelTotalClientNum = "getModelTotalClientNum";
 constexpr auto kGetModelAcceptClientNum = "getModelAcceptClientNum";
 constexpr auto kGetModelRejectClientNum = "getModelRejectClientNum";
+constexpr auto kMinVal = "min_val";
+constexpr auto kMaxVal = "max_val";
+constexpr auto kQuant = "QUANT";
+constexpr auto kDiffSparseQuant = "DIFF_SPARSE_QUANT";
+constexpr auto kNoCompress = "NO_COMPRESS";

 // OptimParamNameToIndex represents every inputs/workspace/outputs parameter's offset when an optimizer kernel is
 // launched.
fl/server/iteration.cc
@@ -588,6 +588,7 @@ void Iteration::Next(bool is_iteration_valid, const std::string &reason) {

   if (LocalMetaStore::GetInstance().verifyAggregationFeatureMap(feature_map)) {
     ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model);
+    ModelStore::GetInstance().StoreCompressModelByIterNum(iteration_num_, model);
     iteration_result_ = IterationResult::kSuccess;
     MS_LOG(INFO) << "Iteration " << iteration_num_ << " is successfully finished.";
   } else {
@@ -599,6 +600,7 @@ void Iteration::Next(bool is_iteration_valid, const std::string &reason) {
   size_t latest_iter_num = iter_to_model.rbegin()->first;
   const auto &model = ModelStore::GetInstance().GetModelByIterNum(latest_iter_num);
   ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model);
+  ModelStore::GetInstance().StoreCompressModelByIterNum(iteration_num_, model);
   iteration_result_ = IterationResult::kFail;
   MS_LOG(WARNING) << "Iteration " << iteration_num_ << " is invalid. Reason: " << reason;
 }
fl/server/kernel/round/get_model_kernel.cc
@@ -92,7 +92,7 @@ void GetModelKernel::GetModel(const schema::RequestGetModel *get_model_req,
     return;
   }
   auto next_req_time = LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp);
-  std::map<std::string, AddressPtr> feature_maps;
+  std::map<std::string, AddressPtr> feature_maps = {};
   size_t current_iter = LocalMetaStore::GetInstance().curr_iter_num();
   size_t get_model_iter = IntToSize(get_model_req->iteration());
   const auto &iter_to_model = ModelStore::GetInstance().iteration_to_model();
@@ -110,6 +110,7 @@ void GetModelKernel::GetModel(const schema::RequestGetModel *get_model_req,
     SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize());
     return;
   }
+
   IncreaseAcceptClientNum();
   auto real_get_model_iter = get_model_iter;
   if (iter_to_model.count(get_model_iter) == 0) {
@@ -118,12 +119,37 @@ void GetModelKernel::GetModel(const schema::RequestGetModel *get_model_req,
                   << " is invalid. Current iteration is " << std::to_string(current_iter);
     real_get_model_iter = latest_iter_num;
   }
-  auto cache = ModelStore::GetInstance().GetModelResponseCache(name_, current_iter, real_get_model_iter);
+  auto download_compress_types = get_model_req->download_compress_types();
+  schema::CompressType compressType =
+    mindspore::fl::compression::CompressExecutor::GetInstance().GetCompressType(download_compress_types);
+  std::string compress_type;
+  if (compressType == schema::CompressType_QUANT) {
+    compress_type = kQuant;
+  } else {
+    compress_type = kNoCompress;
+  }
+  auto cache = ModelStore::GetInstance().GetModelResponseCache(name_, current_iter, real_get_model_iter, compress_type);
   if (cache == nullptr) {
+    // Only download compress weights if client support.
+    std::map<std::string, AddressPtr> compress_feature_maps = {};
+    if (compressType == schema::CompressType_NO_COMPRESS) {
       feature_maps = ModelStore::GetInstance().GetModelByIterNum(real_get_model_iter);
+    } else {
+      auto compressExecutor = mindspore::fl::compression::CompressExecutor::GetInstance();
+      if (compressExecutor.EnableCompressWeight(compressType)) {
+        const auto &iter_to_compress_model = ModelStore::GetInstance().iteration_to_compress_model();
+        if (iter_to_compress_model.count(get_model_iter) == 0) {
+          MS_LOG(DEBUG) << "The iteration of GetCompressModel request " << std::to_string(get_model_iter)
+                        << " is invalid. Current iteration is " << std::to_string(current_iter);
+          compress_feature_maps = ModelStore::GetInstance().GetCompressModelByIterNum(latest_iter_num, compressType);
+        } else {
+          compress_feature_maps = ModelStore::GetInstance().GetCompressModelByIterNum(get_model_iter, compressType);
+        }
+      }
+    }
     BuildGetModelRsp(fbb, schema::ResponseCode_SUCCEED, "Get model for iteration " + std::to_string(get_model_iter),
-                     current_iter, feature_maps, std::to_string(next_req_time));
+                     current_iter, feature_maps, std::to_string(next_req_time), compressType, compress_feature_maps);
-    cache = ModelStore::GetInstance().StoreModelResponseCache(name_, current_iter, real_get_model_iter,
+    cache = ModelStore::GetInstance().StoreModelResponseCache(name_, current_iter, real_get_model_iter, compress_type,
                                                               fbb->GetBufferPointer(), fbb->GetSize());
     if (cache == nullptr) {
       SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize());
@@ -131,7 +157,7 @@ void GetModelKernel::GetModel(const schema::RequestGetModel *get_model_req,
     }
   }
   SendResponseMsgInference(message, cache->data(), cache->size(), ModelStore::GetInstance().RelModelResponseCache);
-  MS_LOG(DEBUG) << "GetModel last iteratin is valid or not: " << Iteration::GetInstance().is_last_iteration_valid()
+  MS_LOG(DEBUG) << "GetModel last iteration is valid or not: " << Iteration::GetInstance().is_last_iteration_valid()
                 << ", next request time is " << next_req_time << ", current iteration is " << current_iter;
   return;
 }
@@ -139,7 +165,8 @@ void GetModelKernel::GetModel(const schema::RequestGetModel *get_model_req,
 void GetModelKernel::BuildGetModelRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode,
                                       const std::string &reason, const size_t iter,
                                       const std::map<std::string, AddressPtr> &feature_maps,
-                                      const std::string &timestamp) {
+                                      const std::string &timestamp, const schema::CompressType &compressType,
+                                      const std::map<std::string, AddressPtr> &compress_feature_maps) {
   if (fbb == nullptr) {
     MS_LOG(ERROR) << "Input fbb is nullptr.";
     return;
@@ -156,12 +183,40 @@ void GetModelKernel::BuildGetModelRsp(const std::shared_ptr<FBBuilder> &fbb, con
   }
   auto fbs_feature_maps_vector = fbb->CreateVector(fbs_feature_maps);

+  // construct compress feature maps with fbs
+  std::vector<flatbuffers::Offset<schema::CompressFeatureMap>> fbs_compress_feature_maps;
+  for (const auto &compress_feature_map : compress_feature_maps) {
+    if (compress_feature_map.first.find(kMinVal) != string::npos ||
+        compress_feature_map.first.find(kMaxVal) != string::npos) {
+      continue;
+    }
+    auto fbs_compress_weight_fullname = fbb->CreateString(compress_feature_map.first);
+    auto fbs_compress_weight_data = fbb->CreateVector(reinterpret_cast<int8_t *>(compress_feature_map.second->addr),
+                                                      compress_feature_map.second->size / sizeof(int8_t));
+
+    const std::string min_val_name = compress_feature_map.first + "." + kMinVal;
+    const std::string max_val_name = compress_feature_map.first + "." + kMaxVal;
+
+    const AddressPtr min_val_ptr = compress_feature_maps.at(min_val_name);
+    const AddressPtr max_val_ptr = compress_feature_maps.at(max_val_name);
+
+    float *fbs_min_val_ptr = reinterpret_cast<float *>(min_val_ptr->addr);
+    float *fbs_max_val_ptr = reinterpret_cast<float *>(max_val_ptr->addr);
+    auto fbs_compress_feature_map = schema::CreateCompressFeatureMap(
+      *(fbb.get()), fbs_compress_weight_fullname, fbs_compress_weight_data, *fbs_min_val_ptr, *fbs_max_val_ptr);
+
+    fbs_compress_feature_maps.push_back(fbs_compress_feature_map);
+  }
+  auto fbs_compress_feature_maps_vector = fbb->CreateVector(fbs_compress_feature_maps);

   schema::ResponseGetModelBuilder rsp_get_model_builder(*(fbb.get()));
   rsp_get_model_builder.add_retcode(static_cast<int>(retcode));
   rsp_get_model_builder.add_reason(fbs_reason);
   rsp_get_model_builder.add_iteration(static_cast<int>(iter));
   rsp_get_model_builder.add_feature_map(fbs_feature_maps_vector);
   rsp_get_model_builder.add_timestamp(fbs_timestamp);
+  rsp_get_model_builder.add_download_compress_type(compressType);
+  rsp_get_model_builder.add_compress_feature_map(fbs_compress_feature_maps_vector);
   auto rsp_get_model = rsp_get_model_builder.Finish();
   fbb->Finish(rsp_get_model);
   return;
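The skipped-name logic above encodes a convention, not a struct: for every quantized weight "w", the compress model map also carries one-float entries "w.min_val" and "w.max_val" (the kMinVal/kMaxVal strings added to common.h), and the response builder looks the pair up by name. A self-contained sketch of that layout, with stand-in Address types (the real ones come from fl/server/common.h):

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <map>
#include <memory>
#include <string>
#include <vector>

// Minimal stand-ins for the server's Address/AddressPtr types (assumed here).
struct Address { void *addr; size_t size; };
using AddressPtr = std::shared_ptr<Address>;

int main() {
  std::vector<int8_t> q = {-128, 0, 127};
  float min_v = -0.5f, max_v = 0.5f;
  std::map<std::string, AddressPtr> compress_feature_maps;
  compress_feature_maps["conv1.weight"] = std::make_shared<Address>(Address{q.data(), q.size()});
  compress_feature_maps["conv1.weight.min_val"] = std::make_shared<Address>(Address{&min_v, sizeof(float)});
  compress_feature_maps["conv1.weight.max_val"] = std::make_shared<Address>(Address{&max_v, sizeof(float)});

  for (const auto &entry : compress_feature_maps) {
    // Skip the companion entries, as BuildGetModelRsp does, then fetch them by name.
    if (entry.first.find("min_val") != std::string::npos ||
        entry.first.find("max_val") != std::string::npos) {
      continue;
    }
    float mn = *reinterpret_cast<float *>(compress_feature_maps.at(entry.first + ".min_val")->addr);
    float mx = *reinterpret_cast<float *>(compress_feature_maps.at(entry.first + ".max_val")->addr);
    std::printf("%s: %zu bytes, min=%f max=%f\n", entry.first.c_str(), entry.second->size, mn, mx);
  }
  return 0;
}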
fl/server/kernel/round/get_model_kernel.h
@@ -25,6 +25,7 @@
 #include "fl/server/executor.h"
 #include "fl/server/kernel/round/round_kernel.h"
 #include "fl/server/kernel/round/round_kernel_factory.h"
+#include "fl/compression/encode_executor.h"

 namespace mindspore {
 namespace fl {
@@ -44,7 +45,9 @@ class GetModelKernel : public RoundKernel {
   void GetModel(const schema::RequestGetModel *get_model_req, const std::shared_ptr<ps::core::MessageHandler> &message);
   void BuildGetModelRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode,
                         const std::string &reason, const size_t iter,
-                        const std::map<std::string, AddressPtr> &feature_maps, const std::string &timestamp);
+                        const std::map<std::string, AddressPtr> &feature_maps, const std::string &timestamp,
+                        const schema::CompressType &compressType = schema::CompressType_NO_COMPRESS,
+                        const std::map<std::string, AddressPtr> &compress_feature_maps = {});

   // The executor is for getting model for getModel request.
   Executor *executor_;
fl/server/kernel/round/start_fl_job_kernel.cc
@@ -126,10 +126,19 @@ bool StartFLJobKernel::Launch(const uint8_t *req_data, size_t len,
   IncreaseAcceptClientNum();
   auto curr_iter_num = LocalMetaStore::GetInstance().curr_iter_num();
   auto last_iteration = curr_iter_num - 1;
-  auto cache = ModelStore::GetInstance().GetModelResponseCache(name_, curr_iter_num, last_iteration);
+  auto download_compress_types = start_fl_job_req->download_compress_types();
+  schema::CompressType compressType =
+    mindspore::fl::compression::CompressExecutor::GetInstance().GetCompressType(download_compress_types);
+  std::string compress_type;
+  if (compressType == schema::CompressType_QUANT) {
+    compress_type = kQuant;
+  } else {
+    compress_type = kNoCompress;
+  }
+  auto cache = ModelStore::GetInstance().GetModelResponseCache(name_, curr_iter_num, last_iteration, compress_type);
   if (cache == nullptr) {
-    StartFLJob(fbb);
+    StartFLJob(fbb, device_meta, start_fl_job_req);
-    cache = ModelStore::GetInstance().StoreModelResponseCache(name_, curr_iter_num, last_iteration,
+    cache = ModelStore::GetInstance().StoreModelResponseCache(name_, curr_iter_num, last_iteration, compress_type,
                                                               fbb->GetBufferPointer(), fbb->GetSize());
     if (cache == nullptr) {
       SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize());
@@ -303,22 +312,40 @@ ResultCode StartFLJobKernel::CountForStartFLJob(const std::shared_ptr<FBBuilder>
   return ResultCode::kSuccess;
 }

-void StartFLJobKernel::StartFLJob(const std::shared_ptr<FBBuilder> &fbb) {
+void StartFLJobKernel::StartFLJob(const std::shared_ptr<FBBuilder> &fbb, const DeviceMeta &,
+                                  const schema::RequestFLJob *start_fl_job_req) {
   size_t last_iteration = LocalMetaStore::GetInstance().curr_iter_num() - 1;
-  auto feature_maps = ModelStore::GetInstance().GetModelByIterNum(last_iteration);
+  std::map<std::string, AddressPtr> feature_maps = {};
+  std::map<std::string, AddressPtr> compress_feature_maps = {};
+
+  // Only download compress weights if client support.
+  auto download_compress_types = start_fl_job_req->download_compress_types();
+  schema::CompressType compressType =
+    mindspore::fl::compression::CompressExecutor::GetInstance().GetCompressType(download_compress_types);
+  if (compressType == schema::CompressType_NO_COMPRESS) {
+    feature_maps = ModelStore::GetInstance().GetModelByIterNum(last_iteration);
     if (feature_maps.empty()) {
       MS_LOG(WARNING) << "The feature map for startFLJob is empty.";
     }
+  } else {
+    if (mindspore::fl::compression::CompressExecutor::GetInstance().EnableCompressWeight(compressType)) {
+      compress_feature_maps = ModelStore::GetInstance().GetCompressModelByIterNum(last_iteration, compressType);
+    }
+  }

   BuildStartFLJobRsp(fbb, schema::ResponseCode_SUCCEED, "success", true,
                      std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)),
-                     feature_maps);
+                     feature_maps, compressType, compress_feature_maps);
   return;
 }

 void StartFLJobKernel::BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode,
                                           const std::string &reason, const bool is_selected,
                                           const std::string &next_req_time,
-                                          std::map<std::string, AddressPtr> feature_maps) {
+                                          const std::map<std::string, AddressPtr> &feature_maps,
+                                          const schema::CompressType &compressType,
+                                          const std::map<std::string, AddressPtr> &compress_feature_maps) {
   if (fbb == nullptr) {
     MS_LOG(WARNING) << "Input fbb is nullptr.";
     return;
@@ -350,6 +377,12 @@ void StartFLJobKernel::BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb,
   auto cipher_public_params =
     schema::CreateCipherPublicParams(*fbb.get(), encrypt_type, pw_params, dp_params, ds_params);
 #endif
+  schema::CompressType upload_compress_type;
+  if (ps::PSContext::instance()->upload_compress_type() == kDiffSparseQuant) {
+    upload_compress_type = schema::CompressType_DIFF_SPARSE_QUANT;
+  } else {
+    upload_compress_type = schema::CompressType_NO_COMPRESS;
+  }

 schema::FLPlanBuilder fl_plan_builder(*(fbb.get()));
 fl_plan_builder.add_fl_name(fbs_fl_name);
@@ -375,6 +408,33 @@ void StartFLJobKernel::BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb,
   }
   auto fbs_feature_maps_vector = fbb->CreateVector(fbs_feature_maps);

+  // construct compress feature maps with fbs
+  std::vector<flatbuffers::Offset<schema::CompressFeatureMap>> fbs_compress_feature_maps;
+  for (const auto &compress_feature_map : compress_feature_maps) {
+    if (compressType == schema::CompressType_QUANT) {
+      if (compress_feature_map.first.find(kMinVal) != string::npos ||
+          compress_feature_map.first.find(kMaxVal) != string::npos) {
+        continue;
+      }
+      auto fbs_compress_weight_fullname = fbb->CreateString(compress_feature_map.first);
+      auto fbs_compress_weight_data = fbb->CreateVector(reinterpret_cast<int8_t *>(compress_feature_map.second->addr),
+                                                        compress_feature_map.second->size / sizeof(int8_t));
+
+      const std::string min_val_name = compress_feature_map.first + "." + kMinVal;
+      const std::string max_val_name = compress_feature_map.first + "." + kMaxVal;
+
+      const AddressPtr min_val_ptr = compress_feature_maps.at(min_val_name);
+      const AddressPtr max_val_ptr = compress_feature_maps.at(max_val_name);
+
+      float *fbs_min_val_ptr = reinterpret_cast<float *>(min_val_ptr->addr);
+      float *fbs_max_val_ptr = reinterpret_cast<float *>(max_val_ptr->addr);
+      auto fbs_compress_feature_map = schema::CreateCompressFeatureMap(
+        *(fbb.get()), fbs_compress_weight_fullname, fbs_compress_weight_data, *fbs_min_val_ptr, *fbs_max_val_ptr);
+      fbs_compress_feature_maps.push_back(fbs_compress_feature_map);
+    }
+  }
+  auto fbs_compress_feature_maps_vector = fbb->CreateVector(fbs_compress_feature_maps);

   schema::ResponseFLJobBuilder rsp_fl_job_builder(*(fbb.get()));
   rsp_fl_job_builder.add_retcode(static_cast<int>(retcode));
   rsp_fl_job_builder.add_reason(fbs_reason);
@@ -383,6 +443,10 @@ void StartFLJobKernel::BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb,
   rsp_fl_job_builder.add_next_req_time(fbs_next_req_time);
   rsp_fl_job_builder.add_fl_plan_config(fbs_fl_plan);
   rsp_fl_job_builder.add_feature_map(fbs_feature_maps_vector);
+  rsp_fl_job_builder.add_download_compress_type(compressType);
+  rsp_fl_job_builder.add_compress_feature_map(fbs_compress_feature_maps_vector);
+  rsp_fl_job_builder.add_upload_compress_type(upload_compress_type);
+  rsp_fl_job_builder.add_upload_sparse_rate(ps::PSContext::instance()->upload_sparse_rate());
   auto rsp_fl_job = rsp_fl_job_builder.Finish();
   fbb->Finish(rsp_fl_job);
   return;
fl/server/kernel/round/start_fl_job_kernel.h
@@ -25,6 +25,9 @@
 #include "fl/server/executor.h"
 #include "fl/server/kernel/round/round_kernel.h"
 #include "fl/server/kernel/round/round_kernel_factory.h"
+#include "schema/fl_job_generated.h"
+#include "schema/cipher_generated.h"
+#include "fl/compression/encode_executor.h"

 namespace mindspore {
 namespace fl {
@@ -56,7 +59,8 @@ class StartFLJobKernel : public RoundKernel {
   // Distributed count service counts for startFLJob.
   ResultCode CountForStartFLJob(const std::shared_ptr<FBBuilder> &fbb, const schema::RequestFLJob *start_fl_job_req);

-  void StartFLJob(const std::shared_ptr<FBBuilder> &fbb);
+  void StartFLJob(const std::shared_ptr<FBBuilder> &fbb, const DeviceMeta &device_meta,
+                  const schema::RequestFLJob *start_fl_job_req);

   bool JudgeFLJobCert(const std::shared_ptr<FBBuilder> &fbb, const schema::RequestFLJob *start_fl_job_req);

@@ -65,7 +69,9 @@ class StartFLJobKernel : public RoundKernel {
   // Build response for startFLJob round no matter success or failure.
   void BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode,
                           const std::string &reason, const bool is_selected, const std::string &next_req_time,
-                          std::map<std::string, AddressPtr> feature_maps = {});
+                          const std::map<std::string, AddressPtr> &feature_maps = {},
+                          const schema::CompressType &compressType = schema::CompressType_NO_COMPRESS,
+                          const std::map<std::string, AddressPtr> &compress_feature_maps = {});

   // The executor is for getting the initial model for startFLJob request.
   Executor *executor_;
fl/server/kernel/round/update_model_kernel.cc
@@ -201,6 +201,7 @@ ResultCode UpdateModelKernel::VerifyUpdateModel(const schema::RequestUpdateModel
   }

   std::unordered_map<std::string, size_t> feature_map;
+  if (ps::PSContext::instance()->upload_compress_type() != kDiffSparseQuant) {
   auto upload_feature_map = update_model_req->feature_map();
   MS_ERROR_IF_NULL_W_RET_VAL(upload_feature_map, ResultCode::kFail);
   for (uint32_t i = 0; i < upload_feature_map->size(); i++) {
@@ -213,11 +214,14 @@ ResultCode UpdateModelKernel::VerifyUpdateModel(const schema::RequestUpdateModel
     size_t weight_size = item->data()->size() * sizeof(float);
     feature_map[weight_full_name] = weight_size;
   }
+  }

   bool verifyFeatureMapIsSuccess;
   if (ps::PSContext::instance()->encrypt_type() == ps::kDSEncryptType && update_model_req->sign() != 0) {
     MS_ERROR_IF_NULL_W_RET_VAL(update_model_req->index_array(), ResultCode::kFail);
     verifyFeatureMapIsSuccess = VerifySignDSFeatureMap(feature_map, update_model_req);
+  } else if (ps::PSContext::instance()->upload_compress_type() == kDiffSparseQuant) {
+    verifyFeatureMapIsSuccess = VerifyUploadCompressFeatureMap(update_model_req);
   } else {
     verifyFeatureMapIsSuccess = LocalMetaStore::GetInstance().verifyAggregationFeatureMap(feature_map);
   }
@@ -280,6 +284,45 @@ bool UpdateModelKernel::VerifySignDSFeatureMap(const std::unordered_map<std::str
   return true;
 }

+bool UpdateModelKernel::VerifyUploadCompressFeatureMap(const schema::RequestUpdateModel *update_model_req) {
+  auto &aggregation_feature_map_ = LocalMetaStore::GetInstance().aggregation_feature_map();
+  auto upload_sparse_rate = update_model_req->upload_sparse_rate();
+  if (upload_sparse_rate != ps::PSContext::instance()->upload_sparse_rate()) {
+    MS_LOG(WARNING) << "The upload_sparse_rate must be equal to the setting in context.";
+    return false;
+  }
+  auto fbs_name_vec = update_model_req->name_vec();
+  if (fbs_name_vec == nullptr) {
+    MS_LOG(WARNING) << "The name_vec is null.";
+    return false;
+  }
+  if (fbs_name_vec->size() == 0) {
+    MS_LOG(WARNING) << "The size of name_vec must be larger than 0.";
+    return false;
+  }
+  if (fbs_name_vec->size() > aggregation_feature_map_.size()) {
+    MS_LOG(WARNING) << "The size of name_vec must be smaller than model in server.";
+    return false;
+  }
+  for (size_t i = 0; i < fbs_name_vec->size(); ++i) {
+    std::string name = fbs_name_vec->Get(i)->str();
+    if (aggregation_feature_map_.count(name) == 0) {
+      MS_LOG(WARNING) << "The upload name: " << name << " is not in model in server.";
+      return false;
+    }
+  }
+  auto fbs_compress_feature_map = update_model_req->compress_feature_map();
+  if (fbs_compress_feature_map == nullptr) {
+    MS_LOG(WARNING) << "The upload compress feature map is null.";
+    return false;
+  }
+  if (fbs_compress_feature_map->size() == 0) {
+    MS_LOG(WARNING) << "The upload compress feature map is empty.";
+    return false;
+  }
+  return true;
+}

 ResultCode UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *update_model_req,
                                           const std::shared_ptr<FBBuilder> &fbb, const DeviceMeta &device_meta) {
   MS_ERROR_IF_NULL_W_RET_VAL(update_model_req, ResultCode::kFail);
@@ -292,6 +335,8 @@ ResultCode UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *upda
   std::map<std::string, UploadData> feature_map;
   if (ps::PSContext::instance()->encrypt_type() == ps::kDSEncryptType) {
     feature_map = ParseSignDSFeatureMap(update_model_req, data_size, &weight_map);
+  } else if (ps::PSContext::instance()->upload_compress_type() == kDiffSparseQuant) {
+    feature_map = ParseUploadCompressFeatureMap(update_model_req, data_size, &weight_map);
   } else {
     feature_map = ParseFeatureMap(update_model_req);
   }
@@ -397,6 +442,89 @@ std::map<std::string, UploadData> UpdateModelKernel::ParseSignDSFeatureMap(
   return feature_map;
 }

+std::map<std::string, UploadData> UpdateModelKernel::ParseUploadCompressFeatureMap(
+  const schema::RequestUpdateModel *update_model_req, size_t data_size,
+  std::map<std::string, std::vector<float>> *weight_map) {
+  MS_ERROR_IF_NULL_W_RET_VAL(update_model_req, {});
+  std::map<std::string, UploadData> feature_map;
+  schema::CompressType upload_compress_type = update_model_req->upload_compress_type();
+  upload_compress_type =
+    mindspore::fl::compression::DecodeExecutor::GetInstance().GetCompressType(upload_compress_type);
+  MS_LOG(INFO) << "This schema upload compress type is: " << upload_compress_type;
+  if (upload_compress_type != schema::CompressType_NO_COMPRESS) {
+    MS_LOG(INFO) << "This upload compress type is DIFF_SPARSE_QUANT.";
+    feature_map = DecodeFeatureMap(weight_map, update_model_req, upload_compress_type, data_size);
+    return feature_map;
+  }
+  MS_LOG(INFO) << "This upload compress type is NO_COMPRESS.";
+  // Some clients upload origin weights.
+  auto fbs_feature_map = update_model_req->feature_map();
+  MS_ERROR_IF_NULL_W_RET_VAL(fbs_feature_map, feature_map);
+  for (uint32_t i = 0; i < fbs_feature_map->size(); i++) {
+    std::string weight_full_name = fbs_feature_map->Get(i)->weight_fullname()->str();
+    float *weight_data = const_cast<float *>(fbs_feature_map->Get(i)->data()->data());
+    size_t weight_size = fbs_feature_map->Get(i)->data()->size() * sizeof(float);
+    UploadData upload_data;
+    upload_data[kNewWeight].addr = weight_data;
+    upload_data[kNewWeight].size = weight_size;
+    feature_map[weight_full_name] = upload_data;
+  }
+  return feature_map;
+}
+
+std::map<std::string, UploadData> UpdateModelKernel::DecodeFeatureMap(
+  std::map<std::string, std::vector<float>> *weight_map, const schema::RequestUpdateModel *update_model_req,
+  schema::CompressType upload_compress_type, size_t data_size) {
+  std::map<std::string, UploadData> feature_map;
+
+  // Get and set decode hyper parameters.
+  auto seed = update_model_req->iteration();
+  MS_LOG(INFO) << "The seed for compression is: " << seed;
+  auto upload_sparse_rate = update_model_req->upload_sparse_rate();
+  MS_LOG(INFO) << "The upload_sparse_rate for compression is: " << upload_sparse_rate;
+  // Get name vector.
+  auto fbs_name_vec = update_model_req->name_vec();
+  std::vector<std::string> name_vec;
+  for (size_t i = 0; i < fbs_name_vec->size(); ++i) {
+    name_vec.emplace_back(fbs_name_vec->Get(i)->str());
+  }
+
+  // Parameter process for decode.
+  auto fbs_compress_feature_map = update_model_req->compress_feature_map();
+  std::vector<mindspore::fl::compression::CompressFeatureMap> compress_feature_maps;
+  for (size_t i = 0; i < fbs_compress_feature_map->size(); ++i) {
+    mindspore::fl::compression::CompressFeatureMap compress_feature_map;
+    int8_t *compress_weight_data = const_cast<int8_t *>(fbs_compress_feature_map->Get(i)->compress_data()->data());
+    size_t compress_weight_size = fbs_compress_feature_map->Get(i)->compress_data()->size();
+    MS_LOG(INFO) << "The compress weight size: " << compress_weight_size;
+    for (size_t j = 0; j < compress_weight_size; ++j) {
+      compress_feature_map.compress_data.emplace_back(compress_weight_data[j]);
+    }
+    compress_feature_map.min_val = fbs_compress_feature_map->Get(i)->min_val();
+    compress_feature_map.max_val = fbs_compress_feature_map->Get(i)->max_val();
+    MS_LOG(INFO) << "Min value: " << compress_feature_map.min_val;
+    MS_LOG(INFO) << "Max value: " << compress_feature_map.max_val;
+    compress_feature_maps.emplace_back(compress_feature_map);
+  }
+
+  // Decode.
+  bool status = mindspore::fl::compression::DecodeExecutor::GetInstance().Decode(
+    weight_map, compress_feature_maps, upload_compress_type, upload_sparse_rate, seed, name_vec, data_size);
+  if (status) {
+    for (size_t i = 0; i < name_vec.size(); ++i) {
+      std::string weight_full_name = name_vec[i];
+      size_t weight_size = (*weight_map)[weight_full_name].size() * sizeof(float);
+      UploadData upload_data;
+      upload_data[kNewWeight].addr = (*weight_map)[weight_full_name].data();
+      upload_data[kNewWeight].size = weight_size;
+      feature_map[weight_full_name] = upload_data;
+    }
+    return feature_map;
+  }
+  MS_LOG(WARNING) << "Decode failed!";
+  return feature_map;
+}
+
 ResultCode UpdateModelKernel::CountForAggregation(const std::string &req_fl_id) {
   std::string count_reason = "";
   if (!DistributedCountService::GetInstance().Count(kCountForAggregation, req_fl_id, &count_reason)) {
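Two details in DecodeFeatureMap deserve a callout: the iteration number doubles as the sparse-mask seed, so all clients of one iteration share a mask, and the final difference step in DeQuantSparseDiff reconstructs each upload as decoded_diff + data_size * old_weight. A toy check of that arithmetic, assuming the client uploads data_size * (new - old), which is the encoding consistent with the server adding data_size * old back in:

#include <cstdio>

int main() {
  // Assumed client-side encoding: diff = data_size * (new - old). The server's
  // difference decode adds data_size * old back, recovering data_size * new;
  // aggregation then divides by the summed data sizes.
  const float old_weight = 0.40f;
  const float new_weight = 0.25f;
  const float data_size = 32.0f;
  const float uploaded_diff = data_size * (new_weight - old_weight);
  const float recovered = uploaded_diff + data_size * old_weight;
  std::printf("recovered=%.2f expected=%.2f\n", recovered, data_size * new_weight);
  return 0;
}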
fl/server/kernel/round/update_model_kernel.h
@@ -30,6 +30,9 @@
 #ifdef ENABLE_ARMOUR
 #include "fl/armour/cipher/cipher_meta_storage.h"
 #endif
+#include "fl/compression/decode_executor.h"
+#include "schema/fl_job_generated.h"
+#include "schema/cipher_generated.h"

 namespace mindspore {
 namespace fl {
@@ -64,8 +67,12 @@ class UpdateModelKernel : public RoundKernel {
   std::map<std::string, UploadData> ParseSignDSFeatureMap(const schema::RequestUpdateModel *update_model_req,
                                                           size_t data_size,
                                                           std::map<std::string, std::vector<float>> *weight_map);
+  std::map<std::string, UploadData> ParseUploadCompressFeatureMap(
+    const schema::RequestUpdateModel *update_model_req, size_t data_size,
+    std::map<std::string, std::vector<float>> *weight_map);
   bool VerifySignDSFeatureMap(const std::unordered_map<std::string, size_t> &model,
                               const schema::RequestUpdateModel *update_model_req);
+  bool VerifyUploadCompressFeatureMap(const schema::RequestUpdateModel *update_model_req);
   ResultCode CountForUpdateModel(const std::shared_ptr<FBBuilder> &fbb,
                                  const schema::RequestUpdateModel *update_model_req);
   sigVerifyResult VerifySignature(const schema::RequestUpdateModel *update_model_req);
@@ -78,6 +85,11 @@ class UpdateModelKernel : public RoundKernel {

   // The time window of one iteration.
   size_t iteration_time_window_{0};

+  // Decode functions of compression.
+  std::map<std::string, UploadData> DecodeFeatureMap(std::map<std::string, std::vector<float>> *weight_map,
+                                                     const schema::RequestUpdateModel *update_model_req,
+                                                     schema::CompressType upload_compress_type, size_t data_size);
 };
 }  // namespace kernel
 }  // namespace server
@@ -44,6 +44,11 @@ void MemoryRegister::StoreCharArray(std::unique_ptr<char[]> *array) {
   MS_ERROR_IF_NULL_WO_RET_VAL(array);
   char_arrays_.push_back(std::move(*array));
 }
+
+void MemoryRegister::StoreFloat32(std::unique_ptr<float> *param) {
+  MS_ERROR_IF_NULL_WO_RET_VAL(param);
+  float_params_.push_back(std::move(*param));
+}
 }  // namespace server
 }  // namespace fl
 }  // namespace mindspore
@@ -24,6 +24,7 @@
 #include <utility>
 #include <typeinfo>
 #include "fl/server/common.h"
+#include "fl/compression/encode_executor.h"
 namespace mindspore {
 namespace fl {
@@ -70,6 +71,25 @@ class MemoryRegister {
     return;
   }
+
+  template <typename T>
+  void RegisterParameter(const std::string &name, std::unique_ptr<T> *param, size_t size) {
+    MS_EXCEPTION_IF_NULL(param);
+    void *data = param->get();
+    AddressPtr addressPtr = std::make_shared<Address>();
+    addressPtr->addr = data;
+    addressPtr->size = size;
+    if (typeid(T) == typeid(float)) {
+      auto float_param = CastUniqueParamPtr<float, T>(param);
+      StoreFloat32(&float_param);
+    } else {
+      MS_LOG(ERROR) << "MemoryRegister does not support type " << typeid(T).name();
+      return;
+    }
+    RegisterAddressPtr(name, addressPtr);
+    return;
+  }
+
  private:
   std::map<std::string, AddressPtr> addresses_;
   std::vector<std::unique_ptr<float[]>> float_arrays_;
@@ -86,6 +106,15 @@ class MemoryRegister {
   std::unique_ptr<T[]> CastUniquePtr(std::unique_ptr<S[]> *array) {
     return std::unique_ptr<T[]>{reinterpret_cast<T *>(array->release())};
   }
+
+  std::vector<std::unique_ptr<float>> float_params_;
+
+  void StoreFloat32(std::unique_ptr<float> *array);
+
+  template <typename T, typename S>
+  std::unique_ptr<T> CastUniqueParamPtr(std::unique_ptr<S> *param) {
+    return std::unique_ptr<T>{reinterpret_cast<T *>(param->release())};
+  }
 };
 }  // namespace server
 }  // namespace fl
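RegisterParameter above pairs two responsibilities: take ownership of a heap-allocated scalar via unique_ptr so its address stays stable, and publish the raw pointer and size in the register's address map. A minimal standalone sketch of the same ownership pattern, with illustrative names only:

#include <cstddef>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <vector>

// Hypothetical sketch of the "own the storage, publish the address" pattern.
struct Slot { void *addr; std::size_t size; };

class ScalarRegister {
 public:
  // Takes ownership of the float; its address stays valid as long as the register lives.
  void Register(const std::string &name, std::unique_ptr<float> value) {
    slots_[name] = Slot{value.get(), sizeof(float)};
    owned_.push_back(std::move(value));  // keep the allocation alive
  }
  const std::map<std::string, Slot> &slots() const { return slots_; }

 private:
  std::map<std::string, Slot> slots_;
  std::vector<std::unique_ptr<float>> owned_;
};

int main() {
  ScalarRegister reg;
  reg.Register("conv1.weight.min_val", std::make_unique<float>(-0.5f));
  std::cout << *static_cast<float *>(reg.slots().at("conv1.weight.min_val").addr) << "\n";
}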
@@ -19,6 +19,7 @@
 #include <string>
 #include <memory>
 #include "fl/server/executor.h"
+#include "pipeline/jit/parse/parse.h"
 #include "include/common/utils/python_adapter.h"
 namespace mindspore {
@@ -33,6 +34,10 @@ void ModelStore::Initialize(uint32_t rank_id, uint32_t max_count) {
   max_model_count_ = max_count;
   initial_model_ = AssignNewModelMemory();
   iteration_to_model_[kInitIterationNum] = initial_model_;
+  std::map<std::string, AddressPtr> model = Executor::GetInstance().GetModel();
+  for (const auto &item : mindspore::fl::compression::kCompressTypeMap) {
+    iteration_to_compress_model_[kInitIterationNum][item.first] = AssignNewCompressModelMemory(item.first, model);
+  }
   model_size_ = ComputeModelSize();
   MS_LOG(INFO) << "Model store checkpoint dir is: " << ps::PSContext::instance()->checkpoint_dir();
 }
@@ -101,6 +106,24 @@ std::map<std::string, AddressPtr> ModelStore::GetModelByIterNum(size_t iteration) {
   return model;
 }
+
+std::map<std::string, AddressPtr> ModelStore::GetCompressModelByIterNum(size_t iteration,
+                                                                        schema::CompressType compressType) {
+  std::unique_lock<std::mutex> lock(model_mtx_);
+  std::map<std::string, AddressPtr> compressModel = {};
+  if (iteration_to_compress_model_.count(iteration) == 0) {
+    MS_LOG(ERROR) << "Compress Model for iteration " << iteration << " is not stored.";
+    return compressModel;
+  }
+  std::map<schema::CompressType, std::shared_ptr<MemoryRegister>> compress_model_map =
+    iteration_to_compress_model_[iteration];
+  if (compress_model_map.count(compressType) == 0) {
+    MS_LOG(ERROR) << "Compress Model for compress type " << compressType << " is not stored.";
+    return compressModel;
+  }
+  compressModel = iteration_to_compress_model_[iteration][compressType]->addresses();
+  return compressModel;
+}
+
 void ModelStore::Reset() {
   std::unique_lock<std::mutex> lock(model_mtx_);
   initial_model_ = iteration_to_model_.rbegin()->second;
@@ -114,6 +137,11 @@ const std::map<size_t, std::shared_ptr<MemoryRegister>> &ModelStore::iteration_to_model() {
   return iteration_to_model_;
 }
+
+const std::map<size_t, CompressTypeMap> &ModelStore::iteration_to_compress_model() {
+  std::unique_lock<std::mutex> lock(model_mtx_);
+  return iteration_to_compress_model_;
+}
+
 size_t ModelStore::model_size() const { return model_size_; }
 
 std::shared_ptr<MemoryRegister> ModelStore::AssignNewModelMemory() {
@@ -146,6 +174,86 @@ std::shared_ptr<MemoryRegister> ModelStore::AssignNewModelMemory() {
   return memory_register;
 }
+
+std::shared_ptr<MemoryRegister> ModelStore::AssignNewCompressModelMemory(
+  schema::CompressType compressType, const std::map<std::string, AddressPtr> &model) {
+  if (model.empty()) {
+    MS_LOG(EXCEPTION) << "Model feature map is empty.";
+    return nullptr;
+  }
+  std::map<string, std::vector<float>> feature_maps;
+  for (auto &feature_map : model) {
+    auto weight_fullname = feature_map.first;
+    auto weight_data = reinterpret_cast<float *>(feature_map.second->addr);
+    std::vector<float> weight_data_vector{weight_data, weight_data + feature_map.second->size / sizeof(float)};
+    feature_maps[weight_fullname] = weight_data_vector;
+  }
+
+  std::map<std::string, mindspore::fl::compression::CompressWeight> compressWeights;
+  bool status = mindspore::fl::compression::CompressExecutor::GetInstance().construct_compress_weight(
+    &compressWeights, feature_maps, compressType);
+  if (!status) {
+    MS_LOG(ERROR) << "Encode failed!";
+    return nullptr;
+  }
+
+  // Assign new memory for the compress model.
+  std::shared_ptr<MemoryRegister> memory_register = std::make_shared<MemoryRegister>();
+  MS_ERROR_IF_NULL_W_RET_VAL(memory_register, nullptr);
+  MS_LOG(INFO) << "Register compressWeight for compressType: " << schema::EnumNameCompressType(compressType);
+
+  for (const auto &compressWeight : compressWeights) {
+    if (compressType == schema::CompressType_QUANT) {
+      std::string compress_weight_name = compressWeight.first;
+      std::string min_val_name = compress_weight_name + "." + kMinVal;
+      std::string max_val_name = compress_weight_name + "." + kMaxVal;
+      size_t compress_weight_size = compressWeight.second.compress_data_len * sizeof(int8_t);
+      auto compress_weight_data = std::make_unique<char[]>(compress_weight_size);
+      auto src_data_size = compress_weight_size;
+      auto dst_data_size = compress_weight_size;
+      int ret =
+        memcpy_s(compress_weight_data.get(), dst_data_size, compressWeight.second.compress_data.data(), src_data_size);
+      if (ret != 0) {
+        MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
+        return nullptr;
+      }
+      memory_register->RegisterArray(compress_weight_name, &compress_weight_data, compress_weight_size);
+      size_t float_size = 1;
+      auto min_val_ptr = std::make_unique<float>(compressWeight.second.min_val);
+      auto max_val_ptr = std::make_unique<float>(compressWeight.second.max_val);
+
+      memory_register->RegisterParameter(min_val_name, &min_val_ptr, float_size);
+      memory_register->RegisterParameter(max_val_name, &max_val_ptr, float_size);
+    }
+  }
+  return memory_register;
+}
+
+void ModelStore::StoreCompressModelByIterNum(size_t iteration, const std::map<std::string, AddressPtr> &new_model) {
+  std::unique_lock<std::mutex> lock(model_mtx_);
+  if (iteration_to_compress_model_.count(iteration) != 0) {
+    MS_LOG(WARNING) << "Compress Model for iteration " << iteration << " is already stored";
+    return;
+  }
+  if (new_model.empty()) {
+    MS_LOG(ERROR) << "Compress Model feature map is empty.";
+    return;
+  }
+
+  iteration_to_compress_model_[iteration] = {};
+  if (iteration_to_compress_model_.size() >= max_model_count_) {
+    auto compress_model_map = iteration_to_compress_model_.begin()->second;
+    compress_model_map.clear();
+    (void)iteration_to_compress_model_.erase(iteration_to_compress_model_.begin());
+  }
+
+  for (const auto &item : mindspore::fl::compression::kCompressTypeMap) {
+    auto memory_register = AssignNewCompressModelMemory(item.first, new_model);
+    MS_ERROR_IF_NULL_WO_RET_VAL(memory_register);
+    iteration_to_compress_model_[iteration][item.first] = memory_register;
+  }
+  return;
+}
+
 size_t ModelStore::ComputeModelSize() {
   std::unique_lock<std::mutex> lock(model_mtx_);
   if (iteration_to_model_.empty()) {
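AssignNewCompressModelMemory delegates the actual encoding to the compress executor and then registers three artifacts per tensor for the QUANT type: the int8 payload plus a <name>.min_val / <name>.max_val scalar pair. For intuition, here is a minimal standalone sketch of a min/max quantization encoder, the inverse of the dequantization sketch above (names are illustrative; the real executor also handles sparsity and other compress types):

#include <algorithm>
#include <cstdint>
#include <vector>

struct QuantizedTensor {
  std::vector<int8_t> codes;
  float min_val;
  float max_val;
};

// Hypothetical sketch: linearly map floats in [min, max] onto int8 codes.
// Assumes a non-empty weight vector.
QuantizedTensor QuantizeMinMax(const std::vector<float> &weights) {
  auto [min_it, max_it] = std::minmax_element(weights.begin(), weights.end());
  QuantizedTensor q{{}, *min_it, *max_it};
  const float range = std::max(q.max_val - q.min_val, 1e-10f);  // avoid divide-by-zero
  q.codes.reserve(weights.size());
  for (float w : weights) {
    int code = static_cast<int>((w - q.min_val) / range * 255.0f + 0.5f) - 128;
    q.codes.push_back(static_cast<int8_t>(std::clamp(code, -128, 127)));
  }
  return q;
}

Precomputing one such encoded model per compress type at store time is what lets getModel answer any advertised client capability without re-encoding on every request.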
@@ -179,12 +287,14 @@ void ModelStore::RelModelResponseCache(const void *data, size_t datalen, void *extra) {
 
 std::shared_ptr<std::vector<uint8_t>> ModelStore::GetModelResponseCache(const string &round_name,
                                                                         size_t cur_iteration_num,
-                                                                        size_t model_iteration_num) {
+                                                                        size_t model_iteration_num,
+                                                                        const std::string &compress_type) {
   std::unique_lock<std::mutex> lock(model_response_cache_lock_);
-  auto it = std::find_if(model_response_cache_.begin(), model_response_cache_.end(),
-                         [&round_name, cur_iteration_num, model_iteration_num](const HttpResponseModelCache &item) {
-                           return item.round_name == round_name && item.cur_iteration_num == cur_iteration_num &&
-                                  item.model_iteration_num == model_iteration_num;
-                         });
+  auto it = std::find_if(
+    model_response_cache_.begin(), model_response_cache_.end(),
+    [&round_name, cur_iteration_num, model_iteration_num, &compress_type](const HttpResponseModelCache &item) {
+      return item.round_name == round_name && item.cur_iteration_num == cur_iteration_num &&
+             item.model_iteration_num == model_iteration_num && item.compress_type == compress_type;
+    });
   if (it == model_response_cache_.end()) {
     return nullptr;
@@ -196,13 +306,15 @@ std::shared_ptr<std::vector<uint8_t>> ModelStore::GetModelResponseCache(const string &round_name,
 
 std::shared_ptr<std::vector<uint8_t>> ModelStore::StoreModelResponseCache(const string &round_name,
                                                                           size_t cur_iteration_num,
-                                                                          size_t model_iteration_num, const void *data,
-                                                                          size_t datalen) {
+                                                                          size_t model_iteration_num,
+                                                                          const std::string &compress_type,
+                                                                          const void *data, size_t datalen) {
   std::unique_lock<std::mutex> lock(model_response_cache_lock_);
-  auto it = std::find_if(model_response_cache_.begin(), model_response_cache_.end(),
-                         [&round_name, cur_iteration_num, model_iteration_num](const HttpResponseModelCache &item) {
-                           return item.round_name == round_name && item.cur_iteration_num == cur_iteration_num &&
-                                  item.model_iteration_num == model_iteration_num;
-                         });
+  auto it = std::find_if(
+    model_response_cache_.begin(), model_response_cache_.end(),
+    [&round_name, cur_iteration_num, model_iteration_num, &compress_type](const HttpResponseModelCache &item) {
+      return item.round_name == round_name && item.cur_iteration_num == cur_iteration_num &&
+             item.model_iteration_num == model_iteration_num && item.compress_type == compress_type;
+    });
   if (it != model_response_cache_.end()) {
     it->reference_count += 1;
@@ -223,6 +335,7 @@ std::shared_ptr<std::vector<uint8_t>> ModelStore::StoreModelResponseCache(const string &round_name,
   item.round_name = round_name;
   item.cur_iteration_num = cur_iteration_num;
   item.model_iteration_num = model_iteration_num;
+  item.compress_type = compress_type;
   item.cache = cache;
   item.reference_count = 1;
   total_add_reference_count += 1;
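The cache key grows from (round, current iteration, model iteration) to also include the compress type, since the same model iteration now has one serialized response per encoding. A minimal sketch of the keying idea, with illustrative names:

#include <cstddef>
#include <list>
#include <memory>
#include <string>
#include <vector>

// Hypothetical sketch: one cached HTTP body per (round, iteration, model iteration, encoding).
struct CacheEntry {
  std::string round_name;
  std::size_t cur_iteration_num = 0;
  std::size_t model_iteration_num = 0;
  std::string compress_type = "NO_COMPRESS";  // without this field, encodings would collide
  std::shared_ptr<std::vector<uint8_t>> body;
};

std::shared_ptr<std::vector<uint8_t>> Lookup(const std::list<CacheEntry> &cache, const std::string &round,
                                             std::size_t cur_iter, std::size_t model_iter,
                                             const std::string &compress) {
  for (const auto &e : cache) {
    if (e.round_name == round && e.cur_iteration_num == cur_iter && e.model_iteration_num == model_iter &&
        e.compress_type == compress) {
      return e.body;
    }
  }
  return nullptr;
}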
@@ -25,6 +25,7 @@
 #include "fl/server/common.h"
 #include "fl/server/memory_register.h"
 #include "fl/server/executor.h"
+#include "fl/compression/encode_executor.h"
 #include "fl/server/local_meta_store.h"
 namespace mindspore {
@@ -36,6 +37,9 @@ constexpr size_t kInitIterationNum = 0;
 // The initial iteration number after ModelStore is reset.
 constexpr size_t kResetInitialIterNum = 1;
+
+// The compress type map.
+using CompressTypeMap = std::map<schema::CompressType, std::shared_ptr<MemoryRegister>>;
+
 // Server framework use ModelStore to store and query models.
 // ModelStore stores multiple models because worker could get models of the previous iterations.
 class ModelStore {
@@ -64,15 +68,25 @@ class ModelStore {
   // Returns the model size, which could be calculated at the initializing phase.
   size_t model_size() const;
+
+  // Get compress model of the given iteration.
+  std::map<std::string, AddressPtr> GetCompressModelByIterNum(size_t iteration, schema::CompressType compressType);
+
+  const std::map<size_t, std::map<schema::CompressType, std::shared_ptr<MemoryRegister>>>
+    &iteration_to_compress_model();
+
+  void StoreCompressModelByIterNum(size_t iteration, const std::map<std::string, AddressPtr> &new_model);
+
   static void RelModelResponseCache(const void *data, size_t datalen, void *extra);
   std::shared_ptr<std::vector<uint8_t>> GetModelResponseCache(const std::string &round_name, size_t cur_iteration_num,
-                                                              size_t model_iteration_num);
+                                                              size_t model_iteration_num,
+                                                              const std::string &compress_type);
   std::shared_ptr<std::vector<uint8_t>> StoreModelResponseCache(const std::string &round_name, size_t cur_iteration_num,
-                                                                size_t model_iteration_num, const void *data,
+                                                                size_t model_iteration_num,
+                                                                const std::string &compress_type, const void *data,
                                                                 size_t datalen);
 
  private:
-  ModelStore() : max_model_count_(0), model_size_(0), iteration_to_model_({}) {}
+  ModelStore() : max_model_count_(0), model_size_(0), iteration_to_model_({}), iteration_to_compress_model_({}) {}
   ~ModelStore() = default;
   ModelStore(const ModelStore &) = delete;
   ModelStore &operator=(const ModelStore &) = delete;
@@ -83,6 +97,9 @@ class ModelStore {
   // model_size_.
   std::shared_ptr<MemoryRegister> AssignNewModelMemory();
+
+  std::shared_ptr<MemoryRegister> AssignNewCompressModelMemory(schema::CompressType compressType,
+                                                               const std::map<std::string, AddressPtr> &model);
+
   // Calculate the model size. This method should be called after iteration_to_model_ is initialized.
   size_t ComputeModelSize();
@@ -95,12 +112,17 @@ class ModelStore {
   // The number of all models stored is max_model_count_.
   std::mutex model_mtx_;
   std::map<size_t, std::shared_ptr<MemoryRegister>> iteration_to_model_;
+
+  // iteration -> (compress type -> compress model)
+  std::map<size_t, std::map<schema::CompressType, std::shared_ptr<MemoryRegister>>> iteration_to_compress_model_;
+
   uint32_t rank_id_;
 
   struct HttpResponseModelCache {
     std::string round_name;  // startFlJob, getModel
     size_t cur_iteration_num = 0;
     size_t model_iteration_num = 0;
+    std::string compress_type = kNoCompress;
     size_t reference_count = 0;
     std::shared_ptr<std::vector<uint8_t>> cache = nullptr;
   };
@@ -507,6 +507,12 @@ PYBIND11_MODULE(_c_expression, m) {
     .def("set_global_iteration_time_window", &PSContext::set_global_iteration_time_window,
          "Set global iteration time window.")
     .def("global_iteration_time_window", &PSContext::global_iteration_time_window, "Get global iteration time window.")
+    .def("set_upload_compress_type", &PSContext::set_upload_compress_type, "Set upload compress type.")
+    .def("upload_compress_type", &PSContext::upload_compress_type, "Get upload compress type.")
+    .def("set_upload_sparse_rate", &PSContext::set_upload_sparse_rate, "Set upload sparse rate.")
+    .def("upload_sparse_rate", &PSContext::upload_sparse_rate, "Get upload sparse rate.")
+    .def("set_download_compress_type", &PSContext::set_download_compress_type, "Set download compress type.")
+    .def("download_compress_type", &PSContext::download_compress_type, "Get download compress type.")
     .def("set_checkpoint_dir", &PSContext::set_checkpoint_dir, "Set server checkpoint directory.")
     .def("checkpoint_dir", &PSContext::checkpoint_dir, "Server checkpoint directory.");
   (void)m.def("_encrypt", &mindspore::pipeline::PyEncrypt, "Encrypt the data.");
@@ -550,6 +550,19 @@ void PSContext::set_global_iteration_time_window(const uint64_t &global_iteration_time_window) {
 
 uint64_t PSContext::global_iteration_time_window() const { return global_iteration_time_window_; }
+
+void PSContext::set_upload_compress_type(const std::string &upload_compress_type) {
+  upload_compress_type_ = upload_compress_type;
+}
+std::string PSContext::upload_compress_type() const { return upload_compress_type_; }
+
+void PSContext::set_upload_sparse_rate(float upload_sparse_rate) { upload_sparse_rate_ = upload_sparse_rate; }
+float PSContext::upload_sparse_rate() const { return upload_sparse_rate_; }
+
+void PSContext::set_download_compress_type(const std::string &download_compress_type) {
+  download_compress_type_ = download_compress_type;
+}
+std::string PSContext::download_compress_type() const { return download_compress_type_; }
+
 std::string PSContext::checkpoint_dir() const { return checkpoint_dir_; }
 
 void PSContext::set_checkpoint_dir(const std::string &checkpoint_dir) { checkpoint_dir_ = checkpoint_dir; }
@@ -40,6 +40,7 @@ constexpr char kPWEncryptType[] = "PW_ENCRYPT";
 constexpr char kStablePWEncryptType[] = "STABLE_PW_ENCRYPT";
 constexpr char kNotEncryptType[] = "NOT_ENCRYPT";
 constexpr char kDSEncryptType[] = "SIGNDS";
+constexpr char kNoCompressType[] = "NO_COMPRESS";
 
 // Use binary data to represent federated learning server's context so that we can judge which round resets the
 // iteration. From right to left, each bit stands for:
@@ -230,6 +231,15 @@ class PSContext {
   void set_global_iteration_time_window(const uint64_t &global_iteration_time_window);
   uint64_t global_iteration_time_window() const;
+
+  void set_upload_compress_type(const std::string &upload_compress_type);
+  std::string upload_compress_type() const;
+
+  void set_upload_sparse_rate(float upload_sparse_rate);
+  float upload_sparse_rate() const;
+
+  void set_download_compress_type(const std::string &download_compress_type);
+  std::string download_compress_type() const;
+
   std::string checkpoint_dir() const;
   void set_checkpoint_dir(const std::string &checkpoint_dir);
@@ -286,6 +296,9 @@ class PSContext {
         server_password_(""),
         http_url_prefix_(""),
         global_iteration_time_window_(3600000),
+        upload_compress_type_(kNoCompressType),
+        upload_sparse_rate_(0.4f),
+        download_compress_type_(kNoCompressType),
         checkpoint_dir_("") {}
   bool ps_enabled_;
   bool is_worker_;
@@ -419,6 +432,13 @@ class PSContext {
 
   // The time window of startFLJob round in millisecond.
   uint64_t global_iteration_time_window_;
+
+  // Hyper parameters for upload compression.
+  std::string upload_compress_type_;
+  float upload_sparse_rate_;
+  // Hyper parameters for download compression.
+  std::string download_compress_type_;
+
   // directory of server checkpoint
   std::string checkpoint_dir_;
 };
@@ -105,6 +105,16 @@ public class FLLiteClient {
         batchSize = flPlan.miniBatch();
         String serverMod = flPlan.serverMode();
         localFLParameter.setServerMod(serverMod);
+        // Get and set hyper parameters for compression.
+        byte uploadCompressType = flJob.uploadCompressType();
+        LOGGER.info(Common.addTag("[startFLJob] [compression] uploadCompressType: " + uploadCompressType));
+        localFLParameter.setUploadCompressType(uploadCompressType);
+        float uploadSparseRate = flJob.uploadSparseRate();
+        LOGGER.info(Common.addTag("[startFLJob] [compression] uploadSparseRate: " + uploadSparseRate));
+        localFLParameter.setUploadSparseRatio(uploadSparseRate);
+        int seed = flJob.iteration();
+        LOGGER.info(Common.addTag("[startFLJob] [compression] seed: " + seed));
+        localFLParameter.setSeed(seed);
         if (Common.checkFLName(flParameter.getFlName())) {
             deprecatedSetBatchSize(batchSize);
         } else {
@@ -446,7 +456,7 @@ public class FLLiteClient {
         return status;
     }
 
-    private Map<String, float[]> getFeatureMap() {
+    public Map<String, float[]> getFeatureMap() {
         Map<String, float[]> featureMap = new HashMap<>();
         if (Common.checkFLName(flParameter.getFlName())) {
             featureMap = deprecatedGetFeatureMap();
@@ -530,8 +540,7 @@ public class FLLiteClient {
                         localFLParameter.getEncryptLevel().toString() + "> : " + curStatus));
                 return curStatus;
             case DP_ENCRYPT:
-                // get the feature map before train
-                oldFeatureMap = localFLParameter.getOldFeatureMap();
+                oldFeatureMap = getFeatureMap();
                 curStatus = secureProtocol.setDPParameter(iteration, dpEps, dpDelta, dpNormClipAdapt, oldFeatureMap);
                 retCode = ResponseCode.SUCCEED;
                 if (curStatus != FLClientStatus.SUCCESS) {
@@ -542,8 +551,7 @@ public class FLLiteClient {
                 LOGGER.info(Common.addTag("[Encrypt] set parameters for DP_ENCRYPT!"));
                 return FLClientStatus.SUCCESS;
             case SIGNDS:
-                // get the feature map before train
-                oldFeatureMap = localFLParameter.getOldFeatureMap();
+                oldFeatureMap = getFeatureMap();
                 curStatus = secureProtocol.setDSParameter(signK, signEps, signThrRatio, signGlobalLr, signDimOut, oldFeatureMap);
                 retCode = ResponseCode.SUCCEED;
                 if (curStatus != FLClientStatus.SUCCESS) {
@@ -18,7 +18,9 @@ package com.mindspore.flclient;
 
 import static com.mindspore.flclient.LocalFLParameter.ALBERT;
 
+import com.mindspore.flclient.compression.CompressMode;
 import com.mindspore.flclient.model.RunType;
+import mindspore.schema.CompressType;
 
 import java.util.ArrayList;
 import java.util.Arrays;
@@ -603,6 +605,16 @@ public class FLParameter {
         this.batchSize = batchSize;
     }
 
+    public byte[] getDownloadCompressTypes() {
+        byte[] downloadCompressTypes = new byte[CompressMode.COMPRESS_TYPE_MAP.size()];
+        int index = 0;
+        for (byte downloadCompressType : CompressMode.COMPRESS_TYPE_MAP.keySet()) {
+            downloadCompressTypes[index] = downloadCompressType;
+            index += 1;
+        }
+        return downloadCompressTypes;
+    }
+
     public int[][] getInputShape() {
         return inputShape;
     }
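getDownloadCompressTypes flattens the client's supported-decoder map into a byte array that rides along in startFLJob and getModel requests, so the server only replies with an encoding this client can undo. The same capability-advertisement idea in a minimal C++ sketch (enum values and names are illustrative, not the schema's actual constants):

#include <cstdint>
#include <map>
#include <vector>

// Hypothetical sketch: advertise every compress type this client can decode.
enum : uint8_t { kNoCompress = 0, kQuant = 1 };

std::vector<uint8_t> SupportedCompressTypes(const std::map<uint8_t, const char *> &decoders) {
  std::vector<uint8_t> types;
  types.reserve(decoders.size());
  for (const auto &entry : decoders) {
    types.push_back(entry.first);  // the server picks one of these for its response
  }
  return types;
}

int main() {
  std::map<uint8_t, const char *> decoders = {{kNoCompress, "identity"}, {kQuant, "min-max int8"}};
  std::vector<uint8_t> advertised = SupportedCompressTypes(decoders);  // {0, 1}
  return static_cast<int>(advertised.size());
}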
@@ -18,6 +18,7 @@ package com.mindspore.flclient;
 
 import com.google.flatbuffers.FlatBufferBuilder;
 
+import com.mindspore.flclient.compression.DecodeExecutor;
 import com.mindspore.flclient.model.AlInferBert;
 import com.mindspore.flclient.model.AlTrainBert;
 import com.mindspore.flclient.model.Client;
@@ -27,11 +28,9 @@ import com.mindspore.flclient.model.SessionUtil;
 import com.mindspore.flclient.model.Status;
 
 import com.mindspore.flclient.model.TrainLenet;
-import mindspore.schema.FeatureMap;
-import mindspore.schema.RequestGetModel;
-import mindspore.schema.ResponseCode;
-import mindspore.schema.ResponseGetModel;
+import mindspore.schema.*;
 
+import java.util.List;
 import java.util.ArrayList;
 import java.util.Date;
 import java.util.logging.Logger;
@@ -94,7 +93,8 @@ public class GetModel {
             throw new IllegalArgumentException();
         }
         RequestGetModelBuilder builder = new RequestGetModelBuilder();
-        return builder.iteration(iteration).flName(name).time().build();
+        return builder.iteration(iteration).flName(name).time()
+                .downloadCompressTypesBuilder(flParameter.getDownloadCompressTypes()).build();
     }
 
     private FLClientStatus deprecatedParseResponseAlbert(ResponseGetModel responseDataBuf) {
@@ -226,11 +226,29 @@ public class GetModel {
         return status;
     }
 
+    private List<FeatureMap> parseFeatureMapList(ResponseGetModel responseDataBuf) {
+        List<FeatureMap> featureMaps;
+        byte compressType = responseDataBuf.downloadCompressType();
+        if (responseDataBuf.downloadCompressType() == mindspore.schema.CompressType.NO_COMPRESS) {
+            featureMaps = new ArrayList<>();
+            for (int i = 0; i < responseDataBuf.featureMapLength(); i++) {
+                featureMaps.add(responseDataBuf.featureMap(i));
+            }
+        } else {
+            List<mindspore.schema.CompressFeatureMap> compressFeatureMapList = new ArrayList<>();
+            for (int i = 0; i < responseDataBuf.compressFeatureMapLength(); i++) {
+                compressFeatureMapList.add(responseDataBuf.compressFeatureMap(i));
+            }
+            featureMaps = DecodeExecutor.getInstance().deCompressWeight(compressType, compressFeatureMapList);
+        }
+        return featureMaps;
+    }
+
     private FLClientStatus parseResponseFeatures(ResponseGetModel responseDataBuf) {
         FLClientStatus status;
         Client client = ClientManager.getClient(flParameter.getFlName());
-        int fmCount = responseDataBuf.featureMapLength();
-        if (fmCount <= 0) {
+        List<FeatureMap> featureMapList = parseFeatureMapList(responseDataBuf);
+        if (featureMapList.size() <= 0) {
             LOGGER.severe(Common.addTag("[getModel] the feature size get from server is zero"));
             retCode = ResponseCode.SystemError;
             return FLClientStatus.FAILED;
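parseFeatureMapList is the decode counterpart of the server's per-type model store: the response carries either plain feature maps or compressed ones, and the download compress type byte selects the path. The same dispatch in a compact, self-contained C++ sketch (types and names are illustrative):

#include <cstdint>
#include <string>
#include <vector>

struct Feature { std::string name; std::vector<float> data; };
struct CompressedFeature { std::string name; std::vector<int8_t> codes; float min_val; float max_val; };

constexpr uint8_t kNoCompress = 0;

// Assumed helper: inverts the min/max quantization (see the dequantization sketch earlier).
std::vector<Feature> Decompress(const std::vector<CompressedFeature> &compressed) {
  std::vector<Feature> out;
  for (const CompressedFeature &c : compressed) {
    Feature f{c.name, {}};
    const float scale = (c.max_val - c.min_val) / 255.0f;
    for (int8_t code : c.codes) {
      f.data.push_back(c.min_val + (static_cast<int>(code) + 128) * scale);
    }
    out.push_back(f);
  }
  return out;
}

std::vector<Feature> ParseFeatureList(uint8_t compress_type, const std::vector<Feature> &plain,
                                      const std::vector<CompressedFeature> &compressed) {
  if (compress_type == kNoCompress) {
    return plain;  // the response carried uncompressed feature maps
  }
  return Decompress(compressed);  // otherwise decode before handing weights to the model
}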
@@ -239,8 +257,8 @@ public class GetModel {
             LOGGER.info(Common.addTag("[getModel] parseResponseFeatures by " + localFLParameter.getServerMod()));
             ArrayList<FeatureMap> trainFeatureMaps = new ArrayList<FeatureMap>();
             ArrayList<FeatureMap> inferFeatureMaps = new ArrayList<FeatureMap>();
-            for (int i = 0; i < fmCount; i++) {
-                FeatureMap feature = responseDataBuf.featureMap(i);
+            for (int i = 0; i < featureMapList.size(); i++) {
+                FeatureMap feature = featureMapList.get(i);
                 if (feature == null) {
                     LOGGER.severe(Common.addTag("[getModel] the feature returned from server is null"));
                     retCode = ResponseCode.SystemError;
@@ -289,8 +307,8 @@ public class GetModel {
         } else if (localFLParameter.getServerMod().equals(ServerMod.FEDERATED_LEARNING.toString())) {
             LOGGER.info(Common.addTag("[getModel] parseResponseFeatures by " + localFLParameter.getServerMod()));
             ArrayList<FeatureMap> featureMaps = new ArrayList<FeatureMap>();
-            for (int i = 0; i < fmCount; i++) {
-                FeatureMap feature = responseDataBuf.featureMap(i);
+            for (int i = 0; i < featureMapList.size(); i++) {
+                FeatureMap feature = featureMapList.get(i);
                 if (feature == null) {
                     LOGGER.severe(Common.addTag("[getModel] the feature returned from server is null"));
                     retCode = ResponseCode.SystemError;
@@ -365,6 +383,7 @@ public class GetModel {
         private int nameOffset = 0;
         private int iteration = 0;
         private int timeStampOffset = 0;
+        private int downloadCompressTypesOffset = 0;
 
         public RequestGetModelBuilder() {
             builder = new FlatBufferBuilder();
@@ -392,11 +411,23 @@ public class GetModel {
             return this;
         }
 
+        private RequestGetModelBuilder downloadCompressTypesBuilder(byte[] downloadCompressTypes) {
+            if (downloadCompressTypes == null || downloadCompressTypes.length == 0) {
+                LOGGER.severe(Common.addTag("[GetModel] the parameter of <downloadCompressTypes> is null or empty," +
+                        " please check!"));
+                throw new IllegalArgumentException();
+            }
+            this.downloadCompressTypesOffset = RequestGetModel.createDownloadCompressTypesVector(builder,
+                    downloadCompressTypes);
+            return this;
+        }
+
         private byte[] build() {
             RequestGetModel.startRequestGetModel(builder);
             RequestGetModel.addFlName(builder, nameOffset);
             RequestGetModel.addIteration(builder, iteration);
             RequestGetModel.addTimestamp(builder, timeStampOffset);
+            RequestGetModel.addDownloadCompressTypes(builder, downloadCompressTypesOffset);
             int root = RequestGetModel.endRequestGetModel(builder);
             builder.finish(root);
             return builder.sizedByteArray();
@@ -22,6 +22,7 @@ import org.bouncycastle.math.ec.rfc7748.X25519;
 
 import java.util.ArrayList;
 import java.util.List;
+import java.util.Map;
 import java.util.logging.Logger;
 
 /**
@@ -83,6 +84,10 @@ public class LocalFLParameter {
     private MSConfig msConfig = new MSConfig();
     private boolean useSSL = true;
     private float lr = 0.1f;
+    private Map<String, float[]> oldFeatureMap;
+    private byte uploadCompressType = 0;
+    private int seed = 0;
+    private float uploadSparseRatio = 0.08f;
 
 
     private LocalFLParameter() {
@@ -250,4 +255,36 @@ public class LocalFLParameter {
     public void setLr(float lr) {
         this.lr = lr;
     }
+
+    public Map<String, float[]> getOldFeatureMap() {
+        return oldFeatureMap;
+    }
+
+    public void setOldFeatureMap(Map<String, float[]> oldFeatureMap) {
+        this.oldFeatureMap = oldFeatureMap;
+    }
+
+    public byte getUploadCompressType() {
+        return uploadCompressType;
+    }
+
+    public void setUploadCompressType(byte uploadCompressType) {
+        this.uploadCompressType = uploadCompressType;
+    }
+
+    public int getSeed() {
+        return seed;
+    }
+
+    public void setSeed(int seed) {
+        this.seed = seed;
+    }
+
+    public float getUploadSparseRatio() {
+        return uploadSparseRatio;
+    }
+
+    public void setUploadSparseRatio(float uploadSparseRatio) {
+        this.uploadSparseRatio = uploadSparseRatio;
+    }
 }
@@ -208,35 +208,34 @@ public class SecureProtocol {
      * @param trainDataSize the size of the train data set.
      * @return the model weights after adding masks.
      */
-    public int[] pwMaskModel(FlatBufferBuilder builder, int trainDataSize, Map<String, float[]> trainedMap) {
+    public Map<String, List<Float>> pwMaskModel(FlatBufferBuilder builder, int trainDataSize, Map<String,
+            float[]> trainedMap) {
+        Map<String, List<Float>> featureMaps = new HashMap<>();
         if (featureMask == null || featureMask.length == 0) {
             LOGGER.severe("[Encrypt] feature mask is null, please check");
-            return new int[0];
+            return new HashMap<>();
         }
         LOGGER.info(String.format("[Encrypt] feature mask size: %s", featureMask.length));
         int featureSize = updateFeatureName.size();
-        int[] featuresMap = new int[featureSize];
         int maskIndex = 0;
         for (int i = 0; i < featureSize; i++) {
             String key = updateFeatureName.get(i);
             float[] data = trainedMap.get(key);
+            List<Float> featureMap = new ArrayList<>();
             LOGGER.info(String.format("[Encrypt] feature name: %s feature size: %s", key, data.length));
             for (int j = 0; j < data.length; j++) {
                 float rawData = data[j];
                 if (maskIndex >= featureMask.length) {
                     LOGGER.severe("[Encrypt] the maskIndex is out of range for array featureMask, please check");
-                    return new int[0];
+                    return new HashMap<>();
                 }
                 float maskData = rawData * trainDataSize + featureMask[maskIndex];
                 maskIndex += 1;
-                data[j] = maskData;
+                featureMap.add(maskData);
             }
-            int featureName = builder.createString(key);
-            int weight = FeatureMap.createDataVector(builder, data);
-            int featureMap = FeatureMap.createFeatureMap(builder, featureName, weight);
-            featuresMap[i] = featureMap;
+            featureMaps.put(key, featureMap);
         }
-        return featuresMap;
+        return featureMaps;
     }
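pwMaskModel now returns the masked weights as plain floats instead of FlatBuffers offsets, leaving serialization (and optional compression) to the caller. The security property is unchanged: paired clients add opposite masks that cancel in the server-side sum. A toy C++ sketch of that cancellation, with made-up values (the mask values here are placeholders, not the protocol's actual derivation):

#include <cassert>
#include <cstddef>
#include <vector>

// Toy sketch: two clients share a pairwise mask; masked updates still sum correctly.
int main() {
  std::vector<float> w_a = {0.5f, -1.0f};   // client A's weights (times its data size, here 1)
  std::vector<float> w_b = {0.25f, 2.0f};   // client B's weights
  std::vector<float> mask = {7.0f, -3.0f};  // derived from a shared pairwise secret

  for (std::size_t i = 0; i < w_a.size(); ++i) {
    float masked_a = w_a[i] + mask[i];      // A adds the mask
    float masked_b = w_b[i] - mask[i];      // B subtracts the same mask
    float sum = masked_a + masked_b;        // masks cancel; server never sees raw updates
    assert(sum == w_a[i] + w_b[i]);
  }
  return 0;
}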
 
     /**
@@ -365,7 +364,9 @@ public class SecureProtocol {
      * @param trainDataSize the size of the train data set.
      * @return the model weights after adding masks.
      */
-    public int[] dpMaskModel(FlatBufferBuilder builder, int trainDataSize, Map<String, float[]> trainedMap) {
+    public Map<String, List<Float>> dpMaskModel(FlatBufferBuilder builder, int trainDataSize,
+                                                Map<String, float[]> trainedMap) {
+        Map<String, List<Float>> featureMaps = new HashMap<>();
         // get feature map
         Map<String, float[]> mapBeforeTrain = modelMap;
         int featureSize = updateFeatureName.size();
@@ -383,7 +384,7 @@ public class SecureProtocol {
             float rawData = data[j];
             if (j >= dataBeforeTrain.length) {
                 LOGGER.severe("[Encrypt] the index j is out of range for array dataBeforeTrain, please check");
-                return new int[0];
+                return new HashMap<>();
             }
             float rawDataBeforeTrain = dataBeforeTrain[j];
             float updateData = rawData - rawDataBeforeTrain;
@@ -393,23 +394,23 @@ public class SecureProtocol {
         updateL2Norm = Math.sqrt(updateL2Norm);
         if (updateL2Norm == 0) {
             LOGGER.severe(Common.addTag("[Encrypt] updateL2Norm is 0, please check"));
-            return new int[0];
+            return new HashMap<>();
         }
         double clipFactor = Math.min(1.0, dpNormClip / updateL2Norm);
 
         // clip and add noise
-        int[] featuresMap = new int[featureSize];
         for (int i = 0; i < featureSize; i++) {
             String key = updateFeatureName.get(i);
             if (!trainedMap.containsKey(key)) {
                 LOGGER.severe("[Encrypt] the key: " + key + " is not in map, please check!");
-                return new int[0];
+                return new HashMap<>();
             }
             float[] data = trainedMap.get(key);
             float[] data2 = new float[data.length];
+            List<Float> featureMap = new ArrayList<>();
             if (!mapBeforeTrain.containsKey(key)) {
                 LOGGER.severe("[Encrypt] the key: " + key + " is not in mapBeforeTrain, please check!");
-                return new int[0];
+                return new HashMap<>();
             }
             float[] dataBeforeTrain = mapBeforeTrain.get(key);
 
@@ -419,7 +420,7 @@ public class SecureProtocol {
             float rawData = data[j];
             if (j >= dataBeforeTrain.length) {
                 LOGGER.severe("[Encrypt] the index j is out of range for array dataBeforeTrain, please check");
-                return new int[0];
+                return new HashMap<>();
             }
             float rawDataBeforeTrain = dataBeforeTrain[j];
             float updateData = rawData - rawDataBeforeTrain;
@@ -432,13 +433,11 @@ public class SecureProtocol {
                 updateData += gaussianNoise;
                 data2[j] = rawDataBeforeTrain + updateData;
                 data2[j] = data2[j] * trainDataSize;
+                featureMap.add(data2[j]);
             }
-            int featureName = builder.createString(key);
-            int weight = FeatureMap.createDataVector(builder, data2);
-            int featureMap = FeatureMap.createFeatureMap(builder, featureName, weight);
-            featuresMap[i] = featureMap;
+            featureMaps.put(key, featureMap);
         }
-        return featuresMap;
+        return featureMaps;
     }
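The DP path computes the update delta against the pre-training weights, rescales it by min(1, dpNormClip / ||delta||2), adds Gaussian noise, and only then multiplies by the data size. A compact C++ sketch of that clip-and-noise step (simplified; the noise scale here is a placeholder, not MindSpore's calibrated sigma):

#include <cmath>
#include <cstddef>
#include <random>
#include <vector>

// Simplified sketch of a differentially private model update.
std::vector<float> DpMask(const std::vector<float> &w_new, const std::vector<float> &w_old,
                          float norm_clip, float sigma, int train_data_size) {
  // L2 norm of the update delta.
  double l2 = 0.0;
  for (std::size_t i = 0; i < w_new.size(); ++i) {
    double d = w_new[i] - w_old[i];
    l2 += d * d;
  }
  l2 = std::sqrt(l2);
  double clip_factor = (l2 > 0.0) ? std::min(1.0, norm_clip / l2) : 1.0;

  std::mt19937 gen(std::random_device{}());
  std::normal_distribution<float> noise(0.0f, sigma);  // placeholder noise scale
  std::vector<float> out(w_new.size());
  for (std::size_t i = 0; i < w_new.size(); ++i) {
    float update = static_cast<float>((w_new[i] - w_old[i]) * clip_factor) + noise(gen);
    out[i] = (w_old[i] + update) * train_data_size;  // weighted for server-side averaging
  }
  return out;
}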
 
     /**
@@ -18,6 +18,7 @@ package com.mindspore.flclient;
 
 import com.google.flatbuffers.FlatBufferBuilder;
 
+import com.mindspore.flclient.compression.DecodeExecutor;
 import com.mindspore.flclient.model.AlInferBert;
 import com.mindspore.flclient.model.AlTrainBert;
 import com.mindspore.flclient.model.Client;
@@ -29,6 +30,7 @@ import com.mindspore.flclient.model.TrainLenet;
 import com.mindspore.flclient.pki.PkiBean;
 import com.mindspore.flclient.pki.PkiUtil;
 
+import mindspore.schema.*;
 import mindspore.schema.FLPlan;
 import mindspore.schema.FeatureMap;
 import mindspore.schema.RequestFLJob;
@@ -38,6 +40,7 @@ import mindspore.schema.ResponseFLJob;
 import java.io.IOException;
 import java.security.cert.Certificate;
 import java.util.ArrayList;
+import java.util.List;
 import java.util.logging.Logger;
 
 import static com.mindspore.flclient.LocalFLParameter.ALBERT;
@@ -119,6 +122,7 @@ public class StartFLJob {
                     .iteration(iteration)
                     .signData(pkiBean.getSignData())
                     .certificateChain(pkiBean.getCertificates())
+                    .downloadCompressTypesBuilder(flParameter.getDownloadCompressTypes())
                     .build();
         }
         return builder.flName(flParameter.getFlName())
@@ -126,6 +130,7 @@ public class StartFLJob {
                 .id(localFLParameter.getFlID())
                 .dataSize(dataSize)
                 .iteration(iteration)
+                .downloadCompressTypesBuilder(flParameter.getDownloadCompressTypes())
                 .build();
     }
@@ -151,8 +156,9 @@ public class StartFLJob {
         ArrayList<FeatureMap> albertFeatureMaps = new ArrayList<FeatureMap>();
         ArrayList<FeatureMap> inferFeatureMaps = new ArrayList<FeatureMap>();
         featureSize = 0;
-        for (int i = 0; i < fmCount; i++) {
-            FeatureMap feature = flJob.featureMap(i);
+        List<FeatureMap> featureMapList = parseFeatureMapList(flJob);
+        for (int i = 0; i < featureMapList.size(); i++) {
+            FeatureMap feature = featureMapList.get(i);
             if (feature == null) {
                 LOGGER.severe(Common.addTag("[startFLJob] the feature returned from server is null"));
                 return FLClientStatus.FAILED;
@@ -233,12 +239,14 @@ public class StartFLJob {
 
     private FLClientStatus deprecatedParseResponseLenet(ResponseFLJob flJob) {
         FLClientStatus status;
-        int fmCount = flJob.featureMapLength();
-        ArrayList<FeatureMap> featureMaps = new ArrayList<FeatureMap>();
         updateFeatureName.clear();
         featureSize = 0;
-        for (int i = 0; i < fmCount; i++) {
-            FeatureMap feature = flJob.featureMap(i);
+        List<FeatureMap> featureMapList = parseFeatureMapList(flJob);
+        ArrayList<FeatureMap> featureMaps = new ArrayList<>();
+
+        for (int i = 0; i < featureMapList.size(); i++) {
+            FeatureMap feature = featureMapList.get(i);
             if (feature == null) {
                 LOGGER.severe(Common.addTag("[startFLJob] the feature returned from server is null"));
                 return FLClientStatus.FAILED;
@@ -267,6 +275,24 @@ public class StartFLJob {
         return FLClientStatus.SUCCESS;
     }

+    private List<FeatureMap> parseFeatureMapList(ResponseFLJob flJob) {
+        List<FeatureMap> featureMaps;
+        byte compressType = flJob.downloadCompressType();
+        if (flJob.downloadCompressType() == mindspore.schema.CompressType.NO_COMPRESS) {
+            LOGGER.info(Common.addTag("[parseFeatureMapList] create no compress feature map."));
+            featureMaps = new ArrayList<>();
+            for (int i = 0; i < flJob.featureMapLength(); i++) {
+                featureMaps.add(flJob.featureMap(i));
+            }
+        } else {
+            List<CompressFeatureMap> compressFeatureMapList = new ArrayList<>();
+            for (int i = 0; i < flJob.compressFeatureMapLength(); i++) {
+                compressFeatureMapList.add(flJob.compressFeatureMap(i));
+            }
+            featureMaps = DecodeExecutor.getInstance().deCompressWeight(compressType, compressFeatureMapList);
+        }
+        return featureMaps;
+    }
+
     private FLClientStatus hybridFeatures(ResponseFLJob flJob) {
         FLClientStatus status;
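parseFeatureMapList is the single entry point that hides whether the server replied with plain FeatureMap entries or quantized CompressFeatureMap entries. For orientation, a hedged fragment showing the decode call in isolation (the payload list would normally be filled from flJob.compressFeatureMap(i)):

    // Sketch, not part of the patch: decoding a QUANT payload directly.
    List<CompressFeatureMap> payload = new ArrayList<>();
    List<FeatureMap> weights =
            DecodeExecutor.getInstance().deCompressWeight(mindspore.schema.CompressType.QUANT, payload);
    // An unsupported compress type, or an empty payload, yields an empty
    // list rather than an exception (see DecodeExecutor below).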
@@ -275,8 +301,23 @@ public class StartFLJob {
         ArrayList<FeatureMap> trainFeatureMaps = new ArrayList<FeatureMap>();
         ArrayList<FeatureMap> inferFeatureMaps = new ArrayList<FeatureMap>();
         featureSize = 0;
+        List<FeatureMap> featureMaps;
+        byte compressType = flJob.downloadCompressType();
+        if (compressType == CompressType.NO_COMPRESS) {
+            featureMaps = new ArrayList<>();
+            for (int i = 0; i < fmCount; i++) {
-                FeatureMap feature = flJob.featureMap(i);
+                featureMaps.add(flJob.featureMap(i));
+            }
+        } else {
+            List<CompressFeatureMap> compressFeatureMapList = new ArrayList<>();
+            for (int i = 0; i < flJob.compressFeatureMapLength(); i++) {
+                compressFeatureMapList.add(flJob.compressFeatureMap(i));
+            }
+            featureMaps = DecodeExecutor.getInstance().deCompressWeight(compressType, compressFeatureMapList);
+            fmCount = featureMaps.size();
+        }
+        for (int i = 0; i < fmCount; i++) {
+            FeatureMap feature = featureMaps.get(i);
             if (feature == null) {
                 LOGGER.severe(Common.addTag("[startFLJob] the feature returned from server is null"));
                 retCode = ResponseCode.SystemError;
@@ -335,8 +376,23 @@ public class StartFLJob {
         int fmCount = flJob.featureMapLength();
         ArrayList<FeatureMap> featureMaps = new ArrayList<FeatureMap>();
         featureSize = 0;
+        byte compressType = flJob.downloadCompressType();
+        List<FeatureMap> parseFeatureMaps;
+        if (compressType == CompressType.NO_COMPRESS) {
+            parseFeatureMaps = new ArrayList<>();
+            for (int i = 0; i < fmCount; i++) {
-                FeatureMap feature = flJob.featureMap(i);
+                parseFeatureMaps.add(flJob.featureMap(i));
+            }
+        } else {
+            List<CompressFeatureMap> compressFeatureMapList = new ArrayList<>();
+            for (int i = 0; i < flJob.compressFeatureMapLength(); i++) {
+                compressFeatureMapList.add(flJob.compressFeatureMap(i));
+            }
+            parseFeatureMaps = DecodeExecutor.getInstance().deCompressWeight(compressType, compressFeatureMapList);
+            fmCount = parseFeatureMaps.size();
+        }
+        for (int i = 0; i < fmCount; i++) {
+            FeatureMap feature = parseFeatureMaps.get(i);
             if (feature == null) {
                 LOGGER.severe(Common.addTag("[startFLJob] the feature returned from server is null"));
                 retCode = ResponseCode.SystemError;
@@ -437,8 +493,8 @@ public class StartFLJob {

         switch (responseRetCode) {
             case (ResponseCode.SUCCEED):
-                if (flJob.featureMapLength() <= 0) {
-                    LOGGER.severe(Common.addTag("[startFLJob] the feature size get from server is zero"));
+                if (flJob.downloadCompressType() == CompressType.NO_COMPRESS && flJob.featureMapLength() <= 0) {
+                    LOGGER.warning(Common.addTag("[startFLJob] the feature size get from server is zero"));
                     retCode = ResponseCode.SystemError;
                     return FLClientStatus.FAILED;
                 }
@@ -484,6 +540,7 @@ public class StartFLJob {
         private int equipCertOffset = 0;
         private int equipCACertOffset = 0;
         private int rootCertOffset = 0;
+        private int downloadCompressTypesOffset = 0;

         public RequestStartFLJobBuilder() {
             builder = new FlatBufferBuilder();
@@ -598,6 +655,17 @@ public class StartFLJob {
             return this;
         }

+        private RequestStartFLJobBuilder downloadCompressTypesBuilder(byte[] downloadCompressTypes) {
+            if (downloadCompressTypes == null || downloadCompressTypes.length == 0) {
+                LOGGER.severe(Common.addTag("[StartFLJob] the parameter of <downloadCompressTypes> is null or empty," +
+                        " please check!"));
+                throw new IllegalArgumentException();
+            }
+            this.downloadCompressTypesOffset = RequestFLJob.createDownloadCompressTypesVector(builder,
+                    downloadCompressTypes);
+            return this;
+        }
+
         /**
          * build protobuffer
          *
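A hedged usage sketch of this builder method (not part of the patch): the bytes are the generated mindspore.schema.CompressType constants, and flParameter.getDownloadCompressTypes() is expected to return such an array.

    // A client that can decode QUANT payloads advertises both options.
    byte[] downloadCompressTypes = new byte[] {
            mindspore.schema.CompressType.NO_COMPRESS,
            mindspore.schema.CompressType.QUANT,
    };
    builder.downloadCompressTypesBuilder(downloadCompressTypes);
    // Passing null or an empty array throws IllegalArgumentException, as above.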
@@ -615,6 +683,7 @@ public class StartFLJob {
             RequestFLJob.addEquipCaCert(builder, equipCACertOffset);
             RequestFLJob.addEquipCert(builder, equipCertOffset);
             RequestFLJob.addKeyAttestation(builder, keyAttestationOffset);
+            RequestFLJob.addDownloadCompressTypes(builder, downloadCompressTypesOffset);
             int root = RequestFLJob.endRequestFLJob(builder);
             builder.finish(root);
             return builder.sizedByteArray();
@@ -147,6 +147,10 @@ public class SyncFLJob {
         LOGGER.info(Common.addTag("[startFLJob] startFLJob succeed, curIteration: " + flLiteClient.getIteration()));
         updateTryTimePerIter(flLiteClient);

+        // Copy weights before training.
+        Map<String, float[]> oldFeatureMap = flLiteClient.getFeatureMap();
+        localFLParameter.setOldFeatureMap(oldFeatureMap);
+
        // create mask
        curStatus = flLiteClient.getFeatureMask();
        if (curStatus == FLClientStatus.RESTART) {
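The snapshot taken here is what EncodeExecutor later subtracts from the trained weights (which are scaled by the local data size on upload), so that only the training delta is sparsified and quantized. A self-contained numeric sketch, with hypothetical tensor name and values:

    import java.util.HashMap;
    import java.util.Map;

    // Sketch: mirrors EncodeExecutor's difference step under the assumption
    // that uploaded weights are already multiplied by trainDataSize.
    public class DiffSketch {
        public static void main(String[] args) {
            int trainDataSize = 10;
            Map<String, float[]> oldFeatureMap = new HashMap<>();
            oldFeatureMap.put("dense.weight", new float[] {0.50f, -0.25f});
            // Trained weights after scaling by trainDataSize.
            float[] trained = {5.2f, -2.4f};
            float[] diff = new float[trained.length];
            for (int i = 0; i < diff.length; i++) {
                diff[i] = trained[i] - oldFeatureMap.get("dense.weight")[i] * trainDataSize;
            }
            System.out.printf("diffs: %.2f, %.2f%n", diff[0], diff[1]); // 0.20, 0.10
        }
    }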
@@ -26,11 +26,15 @@ import com.mindspore.flclient.model.SessionUtil;
 import com.mindspore.flclient.model.Status;
 import com.mindspore.flclient.model.TrainLenet;
 import com.mindspore.lite.MSTensor;
+import com.mindspore.flclient.compression.EncodeExecutor;
+import com.mindspore.flclient.compression.CompressWeight;

 import mindspore.schema.FeatureMap;
+import mindspore.schema.CompressFeatureMap;
 import mindspore.schema.RequestUpdateModel;
 import mindspore.schema.ResponseCode;
 import mindspore.schema.ResponseUpdateModel;
+import static mindspore.schema.CompressType.NO_COMPRESS;

 import java.util.ArrayList;
 import java.util.Date;
@@ -208,6 +212,7 @@ public class UpdateModel {
         private RequestUpdateModel requestUM;
         private FlatBufferBuilder builder;
         private int fmOffset = 0;
+        private int compFmOffset = 0;
         private int nameOffset = 0;
         private int idOffset = 0;
         private int timestampOffset = 0;
@@ -215,8 +220,11 @@ public class UpdateModel {
         private int sign = 0;
         private int indexArrayOffset = 0;
         private int iteration = 0;
+        private byte uploadCompressType = 0;
+        private float uploadSparseRate = 0.0f;
         private EncryptLevel encryptLevel = EncryptLevel.NOT_ENCRYPT;
         private float uploadLossOffset = 0.0f;
+        private int nameVecOffset = 0;

         private RequestUpdateModelBuilder(EncryptLevel encryptLevel) {
             builder = new FlatBufferBuilder();
@@ -294,34 +302,33 @@ public class UpdateModel {
             } else {
                 trainedMap = getFeatureMap();
             }
+            Map<String, List<Float>> featureMaps = new HashMap<>();
             long startTime;
             long endTime;
             switch (encryptLevel) {
                 case PW_ENCRYPT:
-                    int[] fmOffsetsPW = secureProtocol.pwMaskModel(builder, trainDataSize, trainedMap);
-                    if (fmOffsetsPW == null || fmOffsetsPW.length == 0) {
-                        LOGGER.severe("[Encrypt] the return fmOffsetsPW from <secureProtocol.pwMaskModel> is " +
+                    featureMaps = secureProtocol.pwMaskModel(builder, trainDataSize, trainedMap);
+                    if (featureMaps == null || featureMaps.size() == 0) {
+                        LOGGER.severe("[Encrypt] the return featureMaps from <secureProtocol.pwMaskModel> is " +
                                 "null, please check");
                         throw new IllegalArgumentException();
                     }
-                    this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsetsPW);
                     LOGGER.info(Common.addTag("[Encrypt] pairwise mask model ok!"));
-                    return this;
+                    break;
                 case DP_ENCRYPT:
                     startTime = System.currentTimeMillis();
-                    int[] fmOffsetsDP = secureProtocol.dpMaskModel(builder, trainDataSize, trainedMap);
-                    if (fmOffsetsDP == null || fmOffsetsDP.length == 0) {
-                        LOGGER.severe("[Encrypt] the return fmOffsetsDP from <secureProtocol.dpMaskModel> is " +
+                    featureMaps = secureProtocol.dpMaskModel(builder, trainDataSize, trainedMap);
+                    if (featureMaps == null || featureMaps.size() == 0) {
+                        LOGGER.severe("[Encrypt] the return featureMaps from <secureProtocol.dpMaskModel> is " +
                                 "null, please check");
                         retCode = ResponseCode.RequestError;
                         status = FLClientStatus.FAILED;
                         throw new IllegalArgumentException();
                     }
-                    this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsetsDP);
                     LOGGER.info(Common.addTag("[Encrypt] DP mask model ok!"));
                     endTime = System.currentTimeMillis();
-                    LOGGER.info(Common.addTag("[Encrypt] dp time is: " + (endTime - startTime) + "ms"));
-                    return this;
+                    LOGGER.info(Common.addTag("dp time is " + (endTime - startTime) + "ms"));
+                    break;
                 case SIGNDS:
                     startTime = System.currentTimeMillis();
                     // signds alg return indexArray, and package indexArray into flatbuffer.
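In the NOT_ENCRYPT branch of the next hunk, each trained weight is multiplied by trainDataSize before upload. The patch does not show the server side, but the usual reason for this scaling is FedAvg-style weighted averaging: the server can sum client uploads and divide by the total data size. A self-contained arithmetic sketch under that assumption:

    // Sketch (assumption: server aggregates as sum(w_i * n_i) / sum(n_i)).
    public class WeightingSketch {
        public static void main(String[] args) {
            float wA = 0.30f; int nA = 100;   // client A weight and data size
            float wB = 0.60f; int nB = 300;   // client B weight and data size
            float uploadA = wA * nA;          // 30.0, what the NOT_ENCRYPT branch sends
            float uploadB = wB * nB;          // 180.0
            float aggregated = (uploadA + uploadB) / (nA + nB);
            System.out.println(aggregated);   // 0.525 = (0.3*100 + 0.6*300) / 400
        }
    }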
@@ -352,31 +359,104 @@ public class UpdateModel {
                     this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsetsSignds);
                     LOGGER.info(Common.addTag("[Encrypt] SignDS mask model ok!"));
                     endTime = System.currentTimeMillis();
-                    LOGGER.info(Common.addTag("[Encrypt] signds time is: " + (endTime - startTime) + "ms"));
+                    LOGGER.info(Common.addTag("signds time is " + (endTime - startTime) + "ms"));
                     return this;
                 case NOT_ENCRYPT:
                 default:
                     startTime = System.currentTimeMillis();
+                    for (String name : updateFeatureName) {
+                        float[] data = trainedMap.get(name);
+                        List<Float> featureMap = new ArrayList<>();
+                        for (float datum : data) {
+                            featureMap.add(datum * (float) trainDataSize);
+                        }
+                        featureMaps.put(name, featureMap);
+                    }
+                    endTime = System.currentTimeMillis();
+                    LOGGER.info(Common.addTag("not encrypt time is " + (endTime - startTime) + "ms"));
+                    break;
+            }
+            byte uploadCompressType = localFLParameter.getUploadCompressType();
+            if (uploadCompressType != NO_COMPRESS) {
+                startTime = System.currentTimeMillis();
+                this.compFmOffset = buildCompFmOffset(featureMaps, trainDataSize);
+                this.uploadCompressType = localFLParameter.getUploadCompressType();
+                this.uploadSparseRate = localFLParameter.getUploadSparseRatio();
+                this.nameVecOffset = buildNameVecOffset(updateFeatureName);
+                endTime = System.currentTimeMillis();
+                LOGGER.info(Common.addTag("compression time is " + (endTime - startTime) + "ms"));
+                return this;
+            }
+            this.fmOffset = buildFmOffset(featureMaps, updateFeatureName);
+            return this;
+        }
+
+        private int buildCompFmOffset(Map<String, List<Float>> featureMaps, int trainDataSize) {
+            List<CompressWeight> compressWeights = EncodeExecutor.getInstance().encode(featureMaps, trainDataSize);
+            if (compressWeights == null || compressWeights.size() == 0) {
+                LOGGER.severe("[Compression] the return compressWeights from <encodeExecutor.encode> is " +
+                        "null, please check");
+                retCode = ResponseCode.RequestError;
+                status = FLClientStatus.FAILED;
+                throw new IllegalArgumentException();
+            }
+            int compFeatureSize = compressWeights.size();
+            int[] compFmOffsets = new int[compFeatureSize];
+            int index = 0;
+            for (CompressWeight compressWeight : compressWeights) {
+                String weightFullname = compressWeight.getWeightFullname();
+                List<Byte> compressData = compressWeight.getCompressData();
+                float minVal = compressWeight.getMinValue();
+                float maxVal = compressWeight.getMaxValue();
+                byte[] data = new byte[compressData.size()];
+                LOGGER.info(Common.addTag("[updateModel build compressWeight] feature name: "
+                        + weightFullname + ", feature size: " + data.length));
+                for (int j = 0; j < data.length; j++) {
+                    data[j] = compressData.get(j);
+                }
+                int featureName = builder.createString(weightFullname);
+                int weight = CompressFeatureMap.createCompressDataVector(builder, data);
+                int featureMap = CompressFeatureMap.createCompressFeatureMap(builder, featureName, weight,
+                        minVal, maxVal);
+                LOGGER.info(Common.addTag("[Compression]" +
+                        " featureName: " + weightFullname +
+                        ", min_val: " + minVal +
+                        ", max_val: " + maxVal));
+                compFmOffsets[index] = featureMap;
+                index += 1;
+            }
+            return RequestUpdateModel.createCompressFeatureMapVector(builder, compFmOffsets);
+        }
+
+        private int buildNameVecOffset(ArrayList<String> updateFeatureName) {
+            int featureSize = updateFeatureName.size();
+            int[] nameVecOffsets = new int[featureSize];
+            for (int i = 0; i < featureSize; i++) {
+                String key = updateFeatureName.get(i);
+                int featureName = builder.createString(key);
+                nameVecOffsets[i] = featureName;
+            }
+            return RequestUpdateModel.createNameVecVector(builder, nameVecOffsets);
+        }
+
+        private int buildFmOffset(Map<String, List<Float>> featureMaps, ArrayList<String> updateFeatureName) {
             int featureSize = updateFeatureName.size();
             int[] fmOffsets = new int[featureSize];
             for (int i = 0; i < featureSize; i++) {
                 String key = updateFeatureName.get(i);
-                float[] data = trainedMap.get(key);
+                List<Float> featureMap = featureMaps.get(key);
+                float[] data = new float[featureMap.size()];
                 LOGGER.info(Common.addTag("[updateModel build featuresMap] feature name: " + key + " feature " +
                         "size: " + data.length));
                 for (int j = 0; j < data.length; j++) {
-                    data[j] = data[j] * trainDataSize;
+                    data[j] = featureMap.get(j);
                 }
                 int featureName = builder.createString(key);
                 int weight = FeatureMap.createDataVector(builder, data);
-                int featureMap = FeatureMap.createFeatureMap(builder, featureName, weight);
-                fmOffsets[i] = featureMap;
+                int featureMapOff = FeatureMap.createFeatureMap(builder, featureName, weight);
+                fmOffsets[i] = featureMapOff;
             }
-            this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsets);
-            endTime = System.currentTimeMillis();
-            LOGGER.info(Common.addTag("[Encrypt] not encrypt time is: " + (endTime - startTime) + "ms"));
-            return this;
+            return RequestUpdateModel.createFeatureMapVector(builder, fmOffsets);
         }

         /**
@@ -417,6 +497,10 @@ public class UpdateModel {
             RequestUpdateModel.addFlId(this.builder, idOffset);
             RequestUpdateModel.addTimestamp(builder, this.timestampOffset);
             RequestUpdateModel.addIteration(builder, this.iteration);
+            RequestUpdateModel.addCompressFeatureMap(builder, this.compFmOffset);
+            RequestUpdateModel.addUploadCompressType(builder, this.uploadCompressType);
+            RequestUpdateModel.addUploadSparseRate(builder, this.uploadSparseRate);
+            RequestUpdateModel.addNameVec(builder, this.nameVecOffset);
             RequestUpdateModel.addFeatureMap(builder, this.fmOffset);
             RequestUpdateModel.addSignature(builder, this.signDataOffset);
             RequestUpdateModel.addUploadLoss(builder, this.uploadLossOffset);
@@ -0,0 +1,40 @@
/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2019-2022. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.mindspore.flclient.compression;

import java.util.HashMap;
import java.util.Map;

import static mindspore.schema.CompressType.NO_COMPRESS;
import static mindspore.schema.CompressType.QUANT;

/**
 * The compress mode.
 *
 * @since 2021-12-21
 */
public class CompressMode {
    // compress type -> num bits
    public static final Map<Byte, Integer> COMPRESS_TYPE_MAP = new HashMap<>();

    static {
        COMPRESS_TYPE_MAP.put(NO_COMPRESS, -1);
        COMPRESS_TYPE_MAP.put(QUANT, 8);
    }
}
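Usage note: the map keys are the generated CompressType byte constants, so a compress-type byte taken off the wire can index the map directly (after a containsKey check, as DecodeExecutor does). For example:

    // QUANT maps to 8 bits; NO_COMPRESS maps to the sentinel -1 (no quantization).
    int numBits = CompressMode.COMPRESS_TYPE_MAP.get(mindspore.schema.CompressType.QUANT); // 8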
@@ -0,0 +1,83 @@
/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2019-2022. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.mindspore.flclient.compression;

import java.util.List;

/**
 * Compress Weight Bean
 *
 * @since 2021-12-21
 */
public class CompressWeight {
    private String weightFullname;
    private List<Byte> compressData;
    private float minValue;
    private float maxValue;

    public CompressWeight() {
    }

    public CompressWeight(String weightFullname, List<Byte> compressData, float minValue, float maxValue) {
        this.weightFullname = weightFullname;
        this.compressData = compressData;
        this.minValue = minValue;
        this.maxValue = maxValue;
    }

    public String getWeightFullname() {
        return weightFullname;
    }

    public void setWeightFullname(String weightFullname) {
        this.weightFullname = weightFullname;
    }

    public List<Byte> getCompressData() {
        return compressData;
    }

    public void setCompressData(List<Byte> compressData) {
        this.compressData = compressData;
    }

    public float getMinValue() {
        return minValue;
    }

    public void setMinValue(float minValue) {
        this.minValue = minValue;
    }

    public float getMaxValue() {
        return maxValue;
    }

    public void setMaxValue(float maxValue) {
        this.maxValue = maxValue;
    }

    @Override
    public String toString() {
        return "CompressWeight{" +
                "weightFullname='" + weightFullname + '\'' +
                ", compressData=" + compressData +
                ", minValue=" + minValue +
                ", maxValue=" + maxValue +
                '}';
    }
}
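A small hedged usage sketch of this bean (the tensor name is hypothetical): it is produced per tensor by EncodeExecutor and consumed by buildCompFmOffset when packing the flatbuffer request.

    List<Byte> data = java.util.Arrays.asList((byte) -128, (byte) 0, (byte) 127);
    CompressWeight cw = new CompressWeight("albert.encoder.weight", data, -0.02f, 0.03f);
    System.out.println(cw); // CompressWeight{weightFullname='albert.encoder.weight', ...}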
@@ -0,0 +1,115 @@
/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2019-2022. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.mindspore.flclient.compression;

import com.google.flatbuffers.FlatBufferBuilder;

import com.mindspore.flclient.Common;
import com.mindspore.flclient.StartFLJob;
import mindspore.schema.CompressFeatureMap;
import mindspore.schema.FeatureMap;
import mindspore.schema.CompressType;

import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;

import static mindspore.schema.CompressType.QUANT;

/**
 * Decode Executor
 *
 * @since 2021-12-21
 */
public class DecodeExecutor {
    private static final Logger LOGGER = Logger.getLogger(DecodeExecutor.class.toString());

    private static volatile DecodeExecutor compressExecutor;

    private DecodeExecutor() {}

    public static DecodeExecutor getInstance() {
        if (compressExecutor == null) {
            synchronized (DecodeExecutor.class) {
                if (compressExecutor == null) {
                    compressExecutor = new DecodeExecutor();
                }
            }
        }
        return compressExecutor;
    }

    public List<FeatureMap> deCompressWeight(byte compressType, List<CompressFeatureMap> compressFeatureMapList) {
        if (!CompressMode.COMPRESS_TYPE_MAP.containsKey(compressType)) {
            return new ArrayList<>();
        }
        LOGGER.info(Common.addTag("[deCompressWeight] create " + CompressType.name(compressType) + " feature map."));
        int num_bits = CompressMode.COMPRESS_TYPE_MAP.get(compressType);
        if (compressType == QUANT) {
            return deCompressQuantMinMax(compressFeatureMapList, num_bits);
        }
        return new ArrayList<>();
    }

    private List<FeatureMap> deCompressQuantMinMax(List<CompressFeatureMap> compressFeatureMapList, int num_bits) {
        float temp1 = (float) (Math.pow(2, num_bits) - 1);
        float temp2 = (float) Math.pow(2, num_bits - 1);

        Map<String, float[]> deCompressFeatureMaps = new HashMap<>();
        int compressFeatureMapLength = compressFeatureMapList.size();
        for (int i = 0; i < compressFeatureMapLength; i++) {
            CompressFeatureMap compressFeatureMap = compressFeatureMapList.get(i);
            String weightName = compressFeatureMap.weightFullname();
            int compressDataLength = compressFeatureMap.compressDataLength();
            List<Byte> compressWeightList = new ArrayList<>();
            for (int j = 0; j < compressDataLength; j++) {
                compressWeightList.add(compressFeatureMap.compressData(j));
            }
            float minVal = compressFeatureMap.minVal();
            float maxVal = compressFeatureMap.maxVal();
            float scale_value = (float) ((maxVal - minVal) / temp1 + 1e-10);
            float[] params = new float[compressWeightList.size()];
            for (int j = 0; j < params.length; j++) {
                float val = (compressWeightList.get(j).intValue() + temp2) * scale_value + minVal;
                params[j] = val;
            }
            deCompressFeatureMaps.put(weightName, params);
        }

        List<FeatureMap> featureMaps = new ArrayList<>();
        for (String weightName : deCompressFeatureMaps.keySet()) {
            FlatBufferBuilder builder = new FlatBufferBuilder(0);
            int weightFullnameOffset = builder.createString(weightName);
            float[] data = deCompressFeatureMaps.get(weightName);
            int dataOffset = FeatureMap.createDataVector(builder, data);

            FeatureMap.startFeatureMap(builder);
            FeatureMap.addWeightFullname(builder, weightFullnameOffset);
            FeatureMap.addData(builder, dataOffset);

            int orc = FeatureMap.endFeatureMap(builder);
            builder.finish(orc);
            ByteBuffer buf = builder.dataBuffer();
            FeatureMap featureMap = FeatureMap.getRootAsFeatureMap(buf);

            featureMaps.add(featureMap);
        }
        return featureMaps;
    }
}
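A standalone numeric check of the QUANT decode formula above (a sketch, not project code): with 8 bits, a float x in [min, max] is encoded by the uploader (see EncodeExecutor below) as q = round((x - min)/scale - 128) and restored here as (q + 128)*scale + min, with scale = (max - min)/255 + epsilon.

    public class QuantMathSketch {
        public static void main(String[] args) {
            int numBits = 8;
            float temp1 = (float) (Math.pow(2, numBits) - 1);   // 255
            float temp2 = (float) Math.pow(2, numBits - 1);     // 128
            float minVal = -1.0f;
            float maxVal = 1.0f;
            float scale = (float) ((maxVal - minVal) / temp1 + 1e-10); // ~0.007843

            float x = 0.5f;
            int q = Math.round((x - minVal) / scale - temp2);   // 63
            float restored = (q + temp2) * scale + minVal;      // ~0.498
            // The reconstruction error is bounded by about scale/2.
            System.out.println("q=" + q + ", restored=" + restored);
        }
    }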
@@ -0,0 +1,167 @@
/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2019-2022. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.mindspore.flclient.compression;

import com.mindspore.flclient.LocalFLParameter;
import static mindspore.schema.CompressType.DIFF_SPARSE_QUANT;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.PriorityQueue;

/**
 * Encode Executor
 *
 * @since 2021-12-21
 */
public class EncodeExecutor {
    private final LocalFLParameter localFLParameter = LocalFLParameter.getInstance();

    private static volatile EncodeExecutor encodeExecutor;

    private EncodeExecutor() {}

    public static EncodeExecutor getInstance() {
        if (encodeExecutor == null) {
            synchronized (EncodeExecutor.class) {
                if (encodeExecutor == null) {
                    encodeExecutor = new EncodeExecutor();
                }
            }
        }
        return encodeExecutor;
    }

    private static final int multiplier = 2147483647;
    private static final double increment = 4294967294.0;
    private static final int modulo = 48271;

    private List<Integer> constructMaskArray(int paramNum) {
        int seed = localFLParameter.getSeed();
        float uploadSparseRatio = localFLParameter.getUploadSparseRatio();

        List<Integer> maskArray = new ArrayList<>();

        int retain_num = (int) ((float) (paramNum) * uploadSparseRatio);
        for (int i = 0; i < retain_num; ++i) {
            maskArray.add(1);
        }
        for (int i = retain_num; i < paramNum; ++i) {
            maskArray.add(0);
        }

        seed = ((seed + multiplier) * modulo) % multiplier;
        for (int i = 0; i < paramNum; ++i) {
            // generate random number in (0, 1)
            double rand = (double) (seed) / increment + 0.5;
            // update seed
            seed = (seed * modulo) % multiplier;

            int j = (int) (rand * (double) (paramNum - i)) + i;
            int temp = maskArray.get(i);
            maskArray.set(i, maskArray.get(j));
            maskArray.set(j, temp);
        }
        return maskArray;
    }

    public List<CompressWeight> enDiffSparseQuant(Map<String, List<Float>> featureMaps, int numBits,
                                                  int trainDataSize) {
        List<CompressWeight> compressWeights = new ArrayList<>();

        // difference encode
        Map<String, float[]> oldFeatureMap = localFLParameter.getOldFeatureMap();
        Map<String, List<Float>> diffFeatureMaps = new HashMap<>();
        for (String featureMapName : featureMaps.keySet()) {
            List<Float> diffs = new ArrayList<>();
            List<Float> featureMap = featureMaps.get(featureMapName);
            float[] dataBeforeTrain = oldFeatureMap.get(featureMapName);
            int length = dataBeforeTrain.length;
            for (int i = 0; i < length; ++i) {
                float diff = featureMap.get(i) - dataBeforeTrain[i] * (float) trainDataSize;
                diffs.add(diff);
            }
            diffFeatureMaps.put(featureMapName, diffs);
        }

        // sparse encode
        int paramNum = 0;
        for (String featureMapName : diffFeatureMaps.keySet()) {
            int weightSize = diffFeatureMaps.get(featureMapName).size();
            paramNum += weightSize;
        }
        List<Integer> maskArray = constructMaskArray(paramNum);

        Map<String, List<Float>> sparseFeatureMaps = new HashMap<>();
        int index = 0;
        for (String featureMapName : diffFeatureMaps.keySet()) {
            List<Float> sparseFeatureMap = new ArrayList<>();
            List<Float> weight = diffFeatureMaps.get(featureMapName);
            for (Float dataValue : weight) {
                if (maskArray.get(index) == 1) {
                    sparseFeatureMap.add(dataValue);
                }
                index += 1;
            }
            sparseFeatureMaps.put(featureMapName, sparseFeatureMap);
        }

        // quant encode
        float temp1 = (float) (1 << numBits) - 1.0f;
        float temp2 = (float) (1 << (numBits - 1));
        for (String featureMapName : sparseFeatureMaps.keySet()) {
            CompressWeight compressWeight = new CompressWeight();
            compressWeight.setWeightFullname(featureMapName);

            List<Float> sparseFeatureMap = sparseFeatureMaps.get(featureMapName);

            // get min and max value
            float minVal = Float.MAX_VALUE;
            float maxVal = -minVal;
            for (Float value : sparseFeatureMap) {
                if (value < minVal) {
                    minVal = value;
                }
                if (value > maxVal) {
                    maxVal = value;
                }
            }
            compressWeight.setMinValue(minVal);
            compressWeight.setMaxValue(maxVal);
            float scale_value = (maxVal - minVal) / temp1 + 1e-10f;
            List<Byte> compressData = new ArrayList<>();
            for (Float aFloat : sparseFeatureMap) {
                compressData.add((byte) (Math.round((aFloat - minVal) / scale_value - temp2)));
            }
            compressWeight.setCompressData(compressData);
            compressWeights.add(compressWeight);
        }

        return compressWeights;
    }

    public List<CompressWeight> encode(Map<String, List<Float>> featureMaps, int trainDataSize) {
        byte uploadCompressType = localFLParameter.getUploadCompressType();
        if (uploadCompressType == DIFF_SPARSE_QUANT) {
            return enDiffSparseQuant(featureMaps, 8, trainDataSize);
        }
        throw new IllegalArgumentException();
    }
}
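enDiffSparseQuant runs three stages: difference against the pre-training snapshot, pseudo-random sparsification, then 8-bit min-max quantization. Only the retained values are uploaded, so the mask has to be reproducible from the shared seed and upload_sparse_rate on the receiving side; that is an inference from the code, shown here as a self-contained determinism check that mirrors constructMaskArray (including its deliberate int overflow behavior):

    import java.util.ArrayList;
    import java.util.List;

    public class MaskSketch {
        static final int MULTIPLIER = 2147483647;
        static final double INCREMENT = 4294967294.0;
        static final int MODULO = 48271;

        static List<Integer> mask(int seed, float sparseRatio, int paramNum) {
            List<Integer> maskArray = new ArrayList<>();
            int retainNum = (int) (paramNum * sparseRatio);
            // Start with the first retainNum positions retained...
            for (int i = 0; i < paramNum; ++i) {
                maskArray.add(i < retainNum ? 1 : 0);
            }
            // ...then shuffle with the same LCG-style generator.
            seed = ((seed + MULTIPLIER) * MODULO) % MULTIPLIER;
            for (int i = 0; i < paramNum; ++i) {
                double rand = (double) seed / INCREMENT + 0.5;
                seed = (seed * MODULO) % MULTIPLIER;
                int j = (int) (rand * (paramNum - i)) + i;
                int tmp = maskArray.get(i);
                maskArray.set(i, maskArray.get(j));
                maskArray.set(j, tmp);
            }
            return maskArray;
        }

        public static void main(String[] args) {
            // Same seed and ratio on both sides -> identical masks.
            System.out.println(mask(7, 0.5f, 8).equals(mask(7, 0.5f, 8))); // true
        }
    }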
@@ -15,8 +15,9 @@
 """Context for parameter server training mode"""

 import os
-from mindspore._checkparam import Validator
+from mindspore._checkparam import Validator, Rel
 from mindspore._c_expression import PSContext
+from mindspore import log as logger

 _ps_context = None

@@ -79,6 +80,9 @@ _set_ps_context_func_map = {
     "sign_global_lr": ps_context().set_sign_global_lr,
     "sign_dim_out": ps_context().set_sign_dim_out,
     "checkpoint_dir": ps_context().set_checkpoint_dir,
+    "upload_compress_type": ps_context().set_upload_compress_type,
+    "upload_sparse_rate": ps_context().set_upload_sparse_rate,
+    "download_compress_type": ps_context().set_download_compress_type,
 }

 _get_ps_context_func_map = {

@@ -126,7 +130,10 @@ _get_ps_context_func_map = {
     "sign_thr_ratio": ps_context().sign_thr_ratio,
     "sign_global_lr": ps_context().sign_global_lr,
     "sign_dim_out": ps_context().sign_dim_out,
-    "checkpoint_dir": ps_context().checkpoint_dir
+    "checkpoint_dir": ps_context().checkpoint_dir,
+    "upload_compress_type": ps_context().upload_compress_type,
+    "upload_sparse_rate": ps_context().upload_sparse_rate,
+    "download_compress_type": ps_context().download_compress_type,
 }

 _check_positive_int_keys = ["server_num", "scheduler_port", "fl_server_port",

@@ -140,6 +147,15 @@ _check_positive_float_keys = ["update_model_ratio", "client_learning_rate"]

 _check_port_keys = ["scheduler_port", "fl_server_port"]

+_check_string_keys = {
+    "upload_compress_type": ["NO_COMPRESS", "DIFF_SPARSE_QUANT"],
+    "download_compress_type": ["NO_COMPRESS", "QUANT"],
+}
+
+_check_float_range_keys = {
+    "upload_sparse_rate": {"lower_limit": 0.0, "upper_limit": 1.0, "rel": Rel.INC_RIGHT},
+}
+

 def _get_ps_mode_rank():
     ps_rank = ps_context().ps_rank_id()
     if ps_rank == -1:

@@ -183,6 +199,7 @@ def _set_ps_context(**kwargs):
     Examples:
         >>> context.set_ps_context(enable_ps=True, enable_ssl=True, client_password='123456', server_password='123456')
     """
+    kwargs = _check_conflict_value(kwargs)
     for key, value in kwargs.items():
         if key not in _set_ps_context_func_map:
             raise ValueError("Set PS context keyword %s is not recognized!" % key)

@@ -287,6 +304,31 @@ def _check_value(key, value):
     if key in _check_positive_float_keys:
         Validator.check_positive_float(value, key)

+    if key in _check_string_keys:
+        try:
+            string_keys = _check_string_keys[key]
+            Validator.check_string(value, string_keys)
+        except KeyError:
+            pass
+
+    if key in _check_float_range_keys:
+        try:
+            range_keys = _check_float_range_keys[key]
+            Validator.check_float_range(value, **range_keys)
+        except KeyError:
+            pass
+
     if key in _check_port_keys:
         if value < 1 or value > 65535:
             raise ValueError("The range of %s must be 1 to 65535, but got %d." % (key, value))
+
+
+def _check_conflict_value(kwargs):
+    if "upload_compress_type" in kwargs and "encrypt_type" in kwargs:
+        if kwargs["upload_compress_type"] != "NO_COMPRESS" and kwargs["encrypt_type"] in ("SIGNDS", "PW_ENCRYPT"):
+            logger.warning("The '{}' and '{}' are conflicted, and in '{}' mode the"
+                           " 'upload_compress_type' will be 'NO_COMPRESS'".format(kwargs["encrypt_type"],
+                                                                                  kwargs["upload_compress_type"],
+                                                                                  kwargs["encrypt_type"]))
+            kwargs["upload_compress_type"] = "NO_COMPRESS"
+    return kwargs
@@ -47,6 +47,16 @@ table FeatureMap{
     weight_fullname:string;
     data:[float];
 }

+enum CompressType:byte {NO_COMPRESS = 0, DIFF_SPARSE_QUANT = 1, QUANT = 2}
+
+table CompressFeatureMap{
+    weight_fullname:string;
+    compress_data:[int8];
+    min_val:float;
+    max_val:float;
+}
+
 table RequestFLJob{
     fl_name:string;
     fl_id:string;
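The generated Java API for this table is what buildCompFmOffset uses above. A hedged sketch of serializing one CompressFeatureMap by hand (tensor name and values hypothetical):

    FlatBufferBuilder builder = new FlatBufferBuilder();
    int name = builder.createString("dense.weight");
    int data = CompressFeatureMap.createCompressDataVector(builder, new byte[] {-128, 0, 127});
    int cfm = CompressFeatureMap.createCompressFeatureMap(builder, name, data, -1.0f, 1.0f);
    // cfm is an offset that can now be placed in a [CompressFeatureMap]
    // vector on RequestUpdateModel or ResponseFLJob.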
@@ -58,6 +68,7 @@ table RequestFLJob{
     equip_cert:string;
     equip_ca_cert:string;
     root_cert:string;
+    download_compress_types:[CompressType];
 }
 table ResponseFLJob {
     retcode:int;

@@ -68,6 +79,10 @@ table ResponseFLJob {
     fl_plan_config:FLPlan;
     feature_map:[FeatureMap];
     timestamp:string;
+    upload_compress_type:CompressType;
+    upload_sparse_rate:float;
+    download_compress_type:CompressType;
+    compress_feature_map:[CompressFeatureMap];
 }

 table FLPlan {

@@ -94,6 +109,10 @@ table RequestUpdateModel{
     upload_loss:float;
     sign:int;
     index_array:[int];
+    compress_feature_map:[CompressFeatureMap];
+    upload_compress_type:CompressType;
+    upload_sparse_rate:float;
+    name_vec:[string];
 }

 table ResponseUpdateModel{

@@ -132,6 +151,7 @@ table RequestGetModel{
     fl_name:string;
     iteration:int;
     timestamp:string;
+    download_compress_types:[CompressType];
 }
 table ResponseGetModel{
     retcode:int;

@@ -139,6 +159,8 @@ table ResponseGetModel{
     iteration:int;
     feature_map:[FeatureMap];
     timestamp:string;
+    download_compress_type:CompressType;
+    compress_feature_map:[CompressFeatureMap];
 }

 table RequestAsyncGetModel{