diff --git a/mindspore/ccsrc/fl/CMakeLists.txt b/mindspore/ccsrc/fl/CMakeLists.txt index 178fd76f35b..ae137801118 100644 --- a/mindspore/ccsrc/fl/CMakeLists.txt +++ b/mindspore/ccsrc/fl/CMakeLists.txt @@ -51,6 +51,8 @@ if(NOT ENABLE_CPU OR WIN32) list(REMOVE_ITEM _FL_SRC_FILES "armour/cipher/cipher_reconstruct.cc") list(REMOVE_ITEM _FL_SRC_FILES "armour/cipher/cipher_shares.cc") list(REMOVE_ITEM _FL_SRC_FILES "armour/cipher/cipher_unmask.cc") + list(REMOVE_ITEM _FL_SRC_FILES "compression/decode_executor.cc") + list(REMOVE_ITEM _FL_SRC_FILES "compression/encode_executor.cc") endif() if(CMAKE_SYSTEM_NAME MATCHES "Darwin") diff --git a/mindspore/ccsrc/fl/compression/decode_executor.cc b/mindspore/ccsrc/fl/compression/decode_executor.cc new file mode 100644 index 00000000000..ac55791176e --- /dev/null +++ b/mindspore/ccsrc/fl/compression/decode_executor.cc @@ -0,0 +1,150 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "fl/compression/decode_executor.h"
+
+// NOTE(review): the template arguments and cast targets in this file were reconstructed from how each value is
+// used (they are missing from the reviewed copy of the patch); confirm them against the upstream change.
+namespace mindspore {
+namespace fl {
+namespace compression {
+// Builds the 0/1 mask used by random sparsification: 1 marks a parameter whose value is present in the upload,
+// 0 marks a parameter that was dropped. The mask is `retain_num` ones followed by zeros, shuffled with a
+// generator seeded by `seed`, where retain_num = param_num * upload_sparse_rate (truncated).
+// NOTE(review): the uploader presumably derives the same mask from the same seed, so the number sequence must
+// stay bit-for-bit identical -- confirm against the client before changing any of the arithmetic below.
+std::vector<int> DecodeExecutor::ConstructMaskArray(int seed, float upload_sparse_rate, size_t param_num) {
+  // Generator constants, named for the role they play in the recurrence below.
+  constexpr int kModulus = 2147483647;
+  constexpr unsigned int kMultiplier = 48271;
+  constexpr double kRandDivisor = 4294967294.0;
+
+  // Never retain more entries than exist: a rate above 1 (or float rounding of a large param_num) would otherwise
+  // write past the end of mask_array. A negative or NaN rate retains nothing.
+  // NOTE(review): the product is assumed to be computed in float -- confirm, it decides retain_num.
+  const float param_num_f = static_cast<float>(param_num);
+  const float retain_f = param_num_f * upload_sparse_rate;
+  size_t retain_num = 0;
+  if (retain_f > 0.0f) {
+    if (retain_f > param_num_f) {
+      retain_num = param_num;
+    } else {
+      retain_num = size_t(retain_f);
+      if (retain_num > param_num) {
+        retain_num = param_num;
+      }
+    }
+  }
+  if (retain_num == 0) {
+    MS_LOG(WARNING) << "The retain_num is 0, and upload_sparse_rate is too small.";
+  }
+  std::vector<int> mask_array(param_num, 0);
+  for (size_t i = 0; i < retain_num; ++i) {
+    mask_array[i] = 1;
+  }
+
+  // The recurrence depends on 32-bit wrap-around. Doing the add/multiply in unsigned arithmetic yields the same
+  // bit patterns as the former signed expressions without relying on signed overflow (undefined behaviour).
+  unsigned int wrapped = static_cast<unsigned int>(seed) + static_cast<unsigned int>(kModulus);
+  wrapped *= kMultiplier;
+  seed = static_cast<int>(wrapped) % kModulus;
+  for (size_t i = 0; i < param_num; ++i) {
+    // generate random number in (0, 1)
+    const double uniform = static_cast<double>(seed) / kRandDivisor + 0.5;
+    // update seed
+    seed = static_cast<int>(static_cast<unsigned int>(seed) * kMultiplier) % kModulus;
+    // Swap position i with a pseudo-random position in [i, param_num).
+    size_t j = size_t(uniform * static_cast<double>(param_num - i)) + i;
+    if (j >= param_num) {
+      j = param_num - 1;
+    }
+    const int temp = mask_array[i];
+    mask_array[i] = mask_array[j];
+    mask_array[j] = temp;
+  }
+  return mask_array;
+}
+
+// Reverses the three upload compression steps (difference, random sparse, min-max quantization).
+//   weight_map            output: parameter name -> reconstructed weights; only written once all decoding succeeded.
+//   compress_feature_maps quantized values together with the min/max used to quantize them.
+//   num_bits              quantization width; must be in [1, 30] so the shifts below are well defined.
+//   upload_sparse_rate    fraction of parameters retained by the sparsification mask.
+//   seed                  seed of the mask generator.
+//   name_vec              names of the uploaded parameters, in upload order.
+//   data_size             factor applied to the stored weight before the decoded difference is added.
+// Returns false (after logging a warning) when the input is inconsistent with the latest stored model.
+bool DecodeExecutor::DeQuantSparseDiff(std::map<std::string, std::vector<float>> *weight_map,
+                                       const std::vector<CompressFeatureMap> &compress_feature_maps, size_t num_bits,
+                                       float upload_sparse_rate, int seed, const std::vector<std::string> &name_vec,
+                                       size_t data_size) {
+  if (weight_map == nullptr) {
+    MS_LOG(WARNING) << "The weight_map is nullptr.";
+    return false;
+  }
+  constexpr size_t kMaxShiftBits = 30;
+  if (num_bits == 0 || num_bits > kMaxShiftBits) {
+    MS_LOG(WARNING) << "The num_bits " << num_bits << " is not supported.";
+    return false;
+  }
+
+  // origin parameters
+  const auto &iter_to_model = mindspore::fl::server::ModelStore::GetInstance().iteration_to_model();
+  if (iter_to_model.empty()) {
+    MS_LOG(WARNING) << "There is no model stored in server.";
+    return false;
+  }
+  const size_t latest_iter_num = iter_to_model.rbegin()->first;
+  auto feature_maps = mindspore::fl::server::ModelStore::GetInstance().GetModelByIterNum(latest_iter_num);
+
+  // get shape vector and number of upload parameters
+  std::vector<size_t> shape_vec;
+  shape_vec.reserve(name_vec.size());
+  size_t param_num = 0;
+  for (const auto &name : name_vec) {
+    // Look the name up instead of using operator[], which would insert a null entry and then dereference it.
+    const auto iter = feature_maps.find(name);
+    if (iter == feature_maps.end() || iter->second == nullptr || iter->second->addr == nullptr) {
+      MS_LOG(WARNING) << "The upload name: " << name << " is not in model in server.";
+      return false;
+    }
+    const size_t shape = iter->second->size / sizeof(float);
+    shape_vec.emplace_back(shape);
+    param_num += shape;
+  }
+  MS_LOG(DEBUG) << "Compression get last weights success!";
+
+  // quant decode
+  const float temp1 = static_cast<float>(1 << num_bits) - 1.0f;
+  const float temp2 = static_cast<float>(1 << (num_bits - 1));
+  std::vector<float> de_min_max_feature_map;
+  for (const auto &compress_feature_map : compress_feature_maps) {
+    const float min_val = compress_feature_map.min_val;
+    const float max_val = compress_feature_map.max_val;
+    const float scale_val = static_cast<float>(max_val - min_val) / temp1 + 1e-10f;
+    for (const auto quant_value : compress_feature_map.compress_data) {
+      de_min_max_feature_map.emplace_back((static_cast<float>(quant_value) + temp2) * scale_val + min_val);
+    }
+  }
+  MS_LOG(DEBUG) << "Compression quant decode success!";
+
+  // sparse decode
+  const std::vector<int> mask_array = ConstructMaskArray(seed, upload_sparse_rate, param_num);
+  std::vector<std::vector<float>> decompress_feature_maps;
+  decompress_feature_maps.reserve(shape_vec.size());
+  size_t index = 0;
+  size_t de_min_max_feature_map_index = 0;
+  for (const auto &shape : shape_vec) {
+    // Constructed in place and zero-filled, so only the retained positions need to be written.
+    decompress_feature_maps.emplace_back(shape);
+    std::vector<float> &feature_map = decompress_feature_maps.back();
+    for (size_t i = 0; i < shape; ++i) {
+      if (index >= mask_array.size()) {
+        MS_LOG(WARNING) << "The mask_array and parameter shape is not matched.";
+        return false;
+      }
+      if (mask_array[index] == 1) {
+        if (de_min_max_feature_map_index >= de_min_max_feature_map.size()) {
+          MS_LOG(WARNING) << "The number of upload parameters is too small.";
+          return false;
+        }
+        feature_map[i] = de_min_max_feature_map[de_min_max_feature_map_index];
+        de_min_max_feature_map_index += 1;
+      }
+      index += 1;
+    }
+  }
+  MS_LOG(DEBUG) << "Compression sparse decode success!";
+
+  // difference decode
+  const float data_size_factor = static_cast<float>(data_size);
+  for (size_t i = 0; i < decompress_feature_maps.size(); ++i) {
+    const std::string &name = name_vec[i];
+    const std::vector<float> &decoded = decompress_feature_maps[i];
+    // Every name was validated against feature_maps above, so the lookup cannot fail here.
+    const float *weight_data = reinterpret_cast<float *>(feature_maps.at(name)->addr);
+    auto &weight_item = (*weight_map)[name];
+    weight_item.resize(decoded.size());
+    for (size_t j = 0; j < decoded.size(); ++j) {
+      weight_item[j] = decoded[j] + data_size_factor * weight_data[j];
+    }
+  }
+  MS_LOG(DEBUG) << "Compression difference decode success!";
+
+  return true;
+}
+
+// Dispatches on the upload compress type. Only DIFF_SPARSE_QUANT (decoded with 8-bit quantization) is handled;
+// every other type returns false.
+bool DecodeExecutor::Decode(std::map<std::string, std::vector<float>> *weight_map,
+                            const std::vector<CompressFeatureMap> &compress_feature_maps,
+                            schema::CompressType upload_compress_type, float upload_sparse_rate, int seed,
+                            const std::vector<std::string> &name_vec, size_t data_size) {
+  constexpr size_t kQuantNumBits = 8;
+  if (upload_compress_type != schema::CompressType_DIFF_SPARSE_QUANT) {
+    return false;
+  }
+  return DeQuantSparseDiff(weight_map, compress_feature_maps, kQuantNumBits, upload_sparse_rate, seed, name_vec,
+                           data_size);
+}
+
+// Maps the requested upload compress type onto the supported set: DIFF_SPARSE_QUANT is kept, anything else
+// becomes NO_COMPRESS.
+schema::CompressType DecodeExecutor::GetCompressType(schema::CompressType upload_compress_type) {
+  if (upload_compress_type == schema::CompressType_DIFF_SPARSE_QUANT) {
+    MS_LOG(DEBUG) << "This upload compress type is DIFF_SPARSE_QUANT.";
+    return schema::CompressType_DIFF_SPARSE_QUANT;
+  }
+
+  MS_LOG(DEBUG) << "This upload compress type is NO_COMPRESS.";
+  return schema::CompressType_NO_COMPRESS;
+}
+}  // namespace compression
+}  // namespace fl
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/fl/compression/decode_executor.h b/mindspore/ccsrc/fl/compression/decode_executor.h
new file mode 100644
index 00000000000..9ef8da68c51
--- /dev/null
+++ b/mindspore/ccsrc/fl/compression/decode_executor.h
@@ -0,0 +1,74 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_FL_COMPRESSION_DECODE_EXECUTOR_H_
+#define MINDSPORE_CCSRC_FL_COMPRESSION_DECODE_EXECUTOR_H_
+
+// NOTE(review): the header names of the eleven #include directives below are missing from the reviewed copy of
+// this patch; restore them from the upstream change.
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include "proto/comm.pb.h"
+#include "schema/fl_job_generated.h"
+#include "schema/cipher_generated.h"
+#include "fl/server/model_store.h"
+#include "fl/server/common.h"
+#include "ps/ps_context.h"
+
+namespace mindspore {
+namespace fl {
+namespace compression {
+// One uploaded parameter in min-max quantized form.
+// NOTE(review): template arguments in this header were reconstructed from usage; confirm against upstream.
+struct CompressFeatureMap {
+  // Full name of the parameter this entry belongs to.
+  std::string weight_fullname;
+  // Quantized values; the decoder converts each element to float before rescaling it.
+  std::vector<int8_t> compress_data;
+  // Value range used by the quantizer. Default-initialized so a partially filled entry never carries
+  // indeterminate values into the decoder.
+  float min_val = 0.0f;
+  float max_val = 0.0f;
+};
+
+// Decoder for compressed model uploads, accessed through GetInstance().
+class DecodeExecutor {
+ public:
+  // Returns the function-local static instance.
+  static DecodeExecutor &GetInstance() {
+    static DecodeExecutor instance;
+    return instance;
+  }
+
+  // construct mask array for random sparse: one 0/1 entry per parameter, generated from `seed`, with
+  // param_num * upload_sparse_rate entries set to 1
+  std::vector<int> ConstructMaskArray(int seed, float upload_sparse_rate, size_t param_num);
+
+  // decode min_max quantization and random sparse and parameter difference; fills *weight_map with one float
+  // vector per name in name_vec and returns false when the input cannot be decoded
+  bool DeQuantSparseDiff(std::map<std::string, std::vector<float>> *weight_map,
+                         const std::vector<CompressFeatureMap> &compress_feature_maps, size_t num_bits,
+                         float upload_sparse_rate, int seed, const std::vector<std::string> &name_vec,
+                         size_t data_size);
+
+  // decode: dispatches on upload_compress_type and returns false for types that have no decoder
+  bool Decode(std::map<std::string, std::vector<float>> *weight_map,
+              const std::vector<CompressFeatureMap> &compress_feature_maps, schema::CompressType upload_compress_type,
+              float upload_sparse_rate, int seed, const std::vector<std::string> &name_vec, size_t data_size);
+
+  // Returns upload_compress_type when it has a decoder, otherwise CompressType_NO_COMPRESS.
+  schema::CompressType GetCompressType(schema::CompressType upload_compress_type);
+};
+}  // namespace compression
+}  // namespace fl
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_FL_COMPRESSION_DECODE_EXECUTOR_H_
diff --git a/mindspore/ccsrc/fl/compression/encode_executor.cc b/mindspore/ccsrc/fl/compression/encode_executor.cc
new file mode 100644
index 00000000000..d329278d6e7
--- /dev/null
+++ b/mindspore/ccsrc/fl/compression/encode_executor.cc
@@ -0,0 +1,102 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fl/compression/encode_executor.h"
+
+// NOTE(review): the header names of the ten #include directives below are missing from the reviewed copy of
+// this patch; restore them from the upstream change.
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include "fl/server/common.h"
+
+// NOTE(review): the template arguments and cast targets in this file were reconstructed from how each value is
+// used (they are missing from the reviewed copy of the patch); confirm them against the upstream change.
+namespace mindspore {
+namespace fl {
+namespace compression {
+// Returns true when compressType has an entry in kCompressTypeMap, i.e. a known quantization width.
+bool CompressExecutor::EnableCompressWeight(const schema::CompressType compressType) {
+  return kCompressTypeMap.count(compressType) > 0;
+}
+
+// Compresses feature_maps into *compressWeights with the algorithm selected by compressType.
+// Only CompressType_QUANT is implemented; every other type returns false without touching the output.
+bool CompressExecutor::construct_compress_weight(std::map<std::string, CompressWeight> *compressWeights,
+                                                 std::map<std::string, std::vector<float>> feature_maps,
+                                                 const schema::CompressType compressType) {
+  if (compressType != schema::CompressType_QUANT) {
+    return false;
+  }
+  return quant_min_max(compressWeights, feature_maps, kCompressTypeMap.at(compressType));
+}
+
+// Min-max quantization of every tensor in feature_maps.
+// Each value x is stored as round((x - min) / scale - 2^(num_bits - 1)), with
+// scale = (max - min) / (2^num_bits - 1) + 1e-10, and min/max are recorded next to the data.
+// Returns false when the output pointer is null, num_bits is outside [1, 30], or a tensor is empty; entries
+// written before the failing tensor stay in *compressWeights.
+bool CompressExecutor::quant_min_max(std::map<std::string, CompressWeight> *compressWeights,
+                                     std::map<std::string, std::vector<float>> feature_maps, size_t num_bits) {
+  if (compressWeights == nullptr) {
+    MS_LOG(WARNING) << "The compressWeights is nullptr.";
+    return false;
+  }
+  // Shifting an int by 0 - 1 (wrapped) or by 31 and more bits is undefined, so reject such widths up front.
+  constexpr size_t kMaxShiftBits = 30;
+  if (num_bits == 0 || num_bits > kMaxShiftBits) {
+    MS_LOG(WARNING) << "The num_bits " << num_bits << " is not supported.";
+    return false;
+  }
+  // Bounds of the int8_t storage type used by CompressWeight::compress_data.
+  constexpr float kInt8Min = -128.0f;
+  constexpr float kInt8Max = 127.0f;
+  const float temp1 = static_cast<float>(1 << num_bits) - 1.0f;
+  const float temp2 = static_cast<float>(1 << (num_bits - 1));
+  for (const auto &feature_map : feature_maps) {
+    const std::string &weight_name = feature_map.first;
+    const std::vector<float> &weights = feature_map.second;
+    const size_t size = weights.size();
+    if (size == 0) {
+      MS_LOG(WARNING) << "The size of parameters is zero.";
+      return false;
+    }
+    // Seed the range with the first element rather than a +/-1e10 sentinel, so tensors whose values lie outside
+    // that sentinel range still get a correct min/max.
+    float min_value = weights[0];
+    float max_value = weights[0];
+    for (const float feature : weights) {
+      if (feature > max_value) {
+        max_value = feature;
+      }
+      if (feature < min_value) {
+        min_value = feature;
+      }
+    }
+    const float scale_value = (max_value - min_value) / temp1 + 1e-10f;
+
+    // Fill the map entry in place instead of building a temporary and copying its vector.
+    CompressWeight &compressWeight = (*compressWeights)[weight_name];
+    compressWeight.compress_data.clear();
+    compressWeight.compress_data.reserve(size);
+    for (size_t i = 0; i < size; ++i) {
+      auto round_data = round((weights[i] - min_value) / scale_value - temp2);
+      // Converting an out-of-range or NaN floating-point value to int8_t is undefined; saturate first. The
+      // negated comparison also catches NaN.
+      if (!(round_data >= kInt8Min)) {
+        round_data = kInt8Min;
+      } else if (round_data > kInt8Max) {
+        round_data = kInt8Max;
+      }
+      // bit pack can be implemented here in the future
+      auto int8_data = int8_t(round_data);
+      compressWeight.compress_data.emplace_back(int8_data);
+    }
+    compressWeight.min_val = min_value;
+    compressWeight.max_val = max_value;
+    compressWeight.compress_data_len = size;
+  }
+  return true;
+}
+
+// Picks the download compress type from the list sent by the client: CompressType_QUANT when the list contains
+// it, CompressType_NO_COMPRESS otherwise (including when the list is absent).
+schema::CompressType CompressExecutor::GetCompressType(const flatbuffers::Vector<int8_t> *download_compress_types) {
+  if (download_compress_types == nullptr) {
+    MS_LOG(DEBUG) << "The client does not support current download compress type.";
+    return schema::CompressType_NO_COMPRESS;
+  }
+  for (size_t i = 0; i < download_compress_types->size(); ++i) {
+    if (download_compress_types->Get(i) == schema::CompressType_QUANT) {
+      return schema::CompressType_QUANT;
+    }
+  }
+  return schema::CompressType_NO_COMPRESS;
+}
+}  // namespace compression
+}  // namespace fl
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/fl/compression/encode_executor.h b/mindspore/ccsrc/fl/compression/encode_executor.h
new file mode 100644
index 00000000000..7ec5a7257ba
--- /dev/null
+++ b/mindspore/ccsrc/fl/compression/encode_executor.h
@@ -0,0 +1,68 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_FL_COMPRESSION_ENCODE_EXECUTOR_H_
+#define MINDSPORE_CCSRC_FL_COMPRESSION_ENCODE_EXECUTOR_H_
+
+// NOTE(review): the header names of the four #include directives below are missing from the reviewed copy of
+// this patch; restore them from the upstream change.
+#include
+#include
+#include
+#include
+#include "proto/comm.pb.h"
+#include "schema/fl_job_generated.h"
+#include "schema/cipher_generated.h"
+#include "fl/armour/secure_protocol/key_agreement.h"
+#include "ps/ps_context.h"
+#include "ps/core/worker_node.h"
+#include "ps/core/cluster_metadata.h"
+#include "ps/core/communicator/tcp_communicator.h"
+#include "fl/server/common.h"
+
+namespace mindspore {
+namespace fl {
+namespace compression {
+// NOTE(review): template arguments in this header were reconstructed from usage; confirm against upstream.
+// compress type map: schema::CompressType -> num bits
+// A namespace-scope const object has internal linkage, so every translation unit including this header gets
+// its own copy of the map.
+const std::map<schema::CompressType, size_t> kCompressTypeMap = {{schema::CompressType_QUANT, 8}};
+
+// One model parameter in min-max quantized form, as produced by CompressExecutor::quant_min_max.
+struct CompressWeight {
+  // Quantized values, one per original weight.
+  std::vector<int8_t> compress_data;
+  // Number of quantized values. Default-initialized (like min_val/max_val) so a default-constructed entry
+  // never exposes indeterminate values.
+  size_t compress_data_len = 0;
+  // Value range used by the quantizer; needed to dequantize.
+  float min_val = 0.0f;
+  float max_val = 0.0f;
+};
+
+// Encoder for compressed model downloads, accessed through GetInstance().
+class CompressExecutor {
+ public:
+  // Returns the function-local static instance.
+  static CompressExecutor &GetInstance() {
+    static CompressExecutor instance;
+    return instance;
+  }
+
+  // Returns true when compressType has an entry in kCompressTypeMap.
+  bool EnableCompressWeight(const schema::CompressType compressType);
+
+  // Compresses feature_maps into *compressWeights using the algorithm selected by compressType; returns false
+  // for types that have no encoder.
+  bool construct_compress_weight(std::map<std::string, CompressWeight> *compressWeights,
+                                 std::map<std::string, std::vector<float>> feature_maps,
+                                 const schema::CompressType compressType);
+
+  // Min-max quantizes every tensor of feature_maps to num_bits bits and stores the result in *compressWeights.
+  bool quant_min_max(std::map<std::string, CompressWeight> *compressWeights,
+                     std::map<std::string, std::vector<float>> feature_maps, size_t num_bits);
+
+  // Picks the compress type to use for a download from the types listed by the client (may be null).
+  schema::CompressType GetCompressType(const flatbuffers::Vector<int8_t> *download_compress_types);
+};
+}  // namespace compression
+}  // namespace fl
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_FL_COMPRESSION_ENCODE_EXECUTOR_H_
diff --git a/mindspore/ccsrc/fl/server/common.h b/mindspore/ccsrc/fl/server/common.h
index f029d91c9e3..f88cccb52a1 100644
--- a/mindspore/ccsrc/fl/server/common.h
+++ b/mindspore/ccsrc/fl/server/common.h
@@ -149,6 +149,11 @@ constexpr auto kUpdateModelRejectClientNum = "updateModelRejectClientNum";
 constexpr auto kGetModelTotalClientNum = "getModelTotalClientNum";
 constexpr auto kGetModelAcceptClientNum = "getModelAcceptClientNum";
 constexpr auto kGetModelRejectClientNum = "getModelRejectClientNum";
+constexpr auto kMinVal = "min_val"; +constexpr auto kMaxVal = "max_val"; +constexpr auto kQuant = "QUANT"; +constexpr auto kDiffSparseQuant = "DIFF_SPARSE_QUANT"; +constexpr auto kNoCompress = "NO_COMPRESS"; // OptimParamNameToIndex represents every inputs/workspace/outputs parameter's offset when an optimizer kernel is // launched. diff --git a/mindspore/ccsrc/fl/server/iteration.cc b/mindspore/ccsrc/fl/server/iteration.cc index 5b666f2c50a..bfa89558e3b 100644 --- a/mindspore/ccsrc/fl/server/iteration.cc +++ b/mindspore/ccsrc/fl/server/iteration.cc @@ -588,6 +588,7 @@ void Iteration::Next(bool is_iteration_valid, const std::string &reason) { if (LocalMetaStore::GetInstance().verifyAggregationFeatureMap(feature_map)) { ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model); + ModelStore::GetInstance().StoreCompressModelByIterNum(iteration_num_, model); iteration_result_ = IterationResult::kSuccess; MS_LOG(INFO) << "Iteration " << iteration_num_ << " is successfully finished."; } else { @@ -599,6 +600,7 @@ void Iteration::Next(bool is_iteration_valid, const std::string &reason) { size_t latest_iter_num = iter_to_model.rbegin()->first; const auto &model = ModelStore::GetInstance().GetModelByIterNum(latest_iter_num); ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model); + ModelStore::GetInstance().StoreCompressModelByIterNum(iteration_num_, model); iteration_result_ = IterationResult::kFail; MS_LOG(WARNING) << "Iteration " << iteration_num_ << " is invalid. 
Reason: " << reason; } diff --git a/mindspore/ccsrc/fl/server/kernel/round/get_model_kernel.cc b/mindspore/ccsrc/fl/server/kernel/round/get_model_kernel.cc index daa2589d481..b5726843694 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/get_model_kernel.cc +++ b/mindspore/ccsrc/fl/server/kernel/round/get_model_kernel.cc @@ -92,7 +92,7 @@ void GetModelKernel::GetModel(const schema::RequestGetModel *get_model_req, return; } auto next_req_time = LocalMetaStore::GetInstance().value(kCtxIterationNextRequestTimestamp); - std::map feature_maps; + std::map feature_maps = {}; size_t current_iter = LocalMetaStore::GetInstance().curr_iter_num(); size_t get_model_iter = IntToSize(get_model_req->iteration()); const auto &iter_to_model = ModelStore::GetInstance().iteration_to_model(); @@ -110,6 +110,7 @@ void GetModelKernel::GetModel(const schema::RequestGetModel *get_model_req, SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize()); return; } + IncreaseAcceptClientNum(); auto real_get_model_iter = get_model_iter; if (iter_to_model.count(get_model_iter) == 0) { @@ -118,12 +119,37 @@ void GetModelKernel::GetModel(const schema::RequestGetModel *get_model_req, << " is invalid. 
Current iteration is " << std::to_string(current_iter); real_get_model_iter = latest_iter_num; } - auto cache = ModelStore::GetInstance().GetModelResponseCache(name_, current_iter, real_get_model_iter); + auto download_compress_types = get_model_req->download_compress_types(); + schema::CompressType compressType = + mindspore::fl::compression::CompressExecutor::GetInstance().GetCompressType(download_compress_types); + std::string compress_type; + if (compressType == schema::CompressType_QUANT) { + compress_type = kQuant; + } else { + compress_type = kNoCompress; + } + auto cache = ModelStore::GetInstance().GetModelResponseCache(name_, current_iter, real_get_model_iter, compress_type); if (cache == nullptr) { - feature_maps = ModelStore::GetInstance().GetModelByIterNum(real_get_model_iter); + // Only download compress weights if client support. + std::map compress_feature_maps = {}; + if (compressType == schema::CompressType_NO_COMPRESS) { + feature_maps = ModelStore::GetInstance().GetModelByIterNum(real_get_model_iter); + } else { + auto compressExecutor = mindspore::fl::compression::CompressExecutor::GetInstance(); + if (compressExecutor.EnableCompressWeight(compressType)) { + const auto &iter_to_compress_model = ModelStore::GetInstance().iteration_to_compress_model(); + if (iter_to_compress_model.count(get_model_iter) == 0) { + MS_LOG(DEBUG) << "The iteration of GetCompressModel request " << std::to_string(get_model_iter) + << " is invalid. 
Current iteration is " << std::to_string(current_iter); + compress_feature_maps = ModelStore::GetInstance().GetCompressModelByIterNum(latest_iter_num, compressType); + } else { + compress_feature_maps = ModelStore::GetInstance().GetCompressModelByIterNum(get_model_iter, compressType); + } + } + } BuildGetModelRsp(fbb, schema::ResponseCode_SUCCEED, "Get model for iteration " + std::to_string(get_model_iter), - current_iter, feature_maps, std::to_string(next_req_time)); - cache = ModelStore::GetInstance().StoreModelResponseCache(name_, current_iter, real_get_model_iter, + current_iter, feature_maps, std::to_string(next_req_time), compressType, compress_feature_maps); + cache = ModelStore::GetInstance().StoreModelResponseCache(name_, current_iter, real_get_model_iter, compress_type, fbb->GetBufferPointer(), fbb->GetSize()); if (cache == nullptr) { SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize()); @@ -131,7 +157,7 @@ void GetModelKernel::GetModel(const schema::RequestGetModel *get_model_req, } } SendResponseMsgInference(message, cache->data(), cache->size(), ModelStore::GetInstance().RelModelResponseCache); - MS_LOG(DEBUG) << "GetModel last iteratin is valid or not: " << Iteration::GetInstance().is_last_iteration_valid() + MS_LOG(DEBUG) << "GetModel last iteration is valid or not: " << Iteration::GetInstance().is_last_iteration_valid() << ", next request time is " << next_req_time << ", current iteration is " << current_iter; return; } @@ -139,7 +165,8 @@ void GetModelKernel::GetModel(const schema::RequestGetModel *get_model_req, void GetModelKernel::BuildGetModelRsp(const std::shared_ptr &fbb, const schema::ResponseCode retcode, const std::string &reason, const size_t iter, const std::map &feature_maps, - const std::string ×tamp) { + const std::string ×tamp, const schema::CompressType &compressType, + const std::map &compress_feature_maps) { if (fbb == nullptr) { MS_LOG(ERROR) << "Input fbb is nullptr."; return; @@ -156,12 +183,40 @@ void 
GetModelKernel::BuildGetModelRsp(const std::shared_ptr &fbb, con } auto fbs_feature_maps_vector = fbb->CreateVector(fbs_feature_maps); + // construct compress feature maps with fbs + std::vector> fbs_compress_feature_maps; + for (const auto &compress_feature_map : compress_feature_maps) { + if (compress_feature_map.first.find(kMinVal) != string::npos || + compress_feature_map.first.find(kMaxVal) != string::npos) { + continue; + } + auto fbs_compress_weight_fullname = fbb->CreateString(compress_feature_map.first); + auto fbs_compress_weight_data = fbb->CreateVector(reinterpret_cast(compress_feature_map.second->addr), + compress_feature_map.second->size / sizeof(int8_t)); + + const std::string min_val_name = compress_feature_map.first + "." + kMinVal; + const std::string max_val_name = compress_feature_map.first + "." + kMaxVal; + + const AddressPtr min_val_ptr = compress_feature_maps.at(min_val_name); + const AddressPtr max_val_ptr = compress_feature_maps.at(max_val_name); + + float *fbs_min_val_ptr = reinterpret_cast(min_val_ptr->addr); + float *fbs_max_val_ptr = reinterpret_cast(max_val_ptr->addr); + auto fbs_compress_feature_map = schema::CreateCompressFeatureMap( + *(fbb.get()), fbs_compress_weight_fullname, fbs_compress_weight_data, *fbs_min_val_ptr, *fbs_max_val_ptr); + + fbs_compress_feature_maps.push_back(fbs_compress_feature_map); + } + auto fbs_compress_feature_maps_vector = fbb->CreateVector(fbs_compress_feature_maps); + schema::ResponseGetModelBuilder rsp_get_model_builder(*(fbb.get())); rsp_get_model_builder.add_retcode(static_cast(retcode)); rsp_get_model_builder.add_reason(fbs_reason); rsp_get_model_builder.add_iteration(static_cast(iter)); rsp_get_model_builder.add_feature_map(fbs_feature_maps_vector); rsp_get_model_builder.add_timestamp(fbs_timestamp); + rsp_get_model_builder.add_download_compress_type(compressType); + rsp_get_model_builder.add_compress_feature_map(fbs_compress_feature_maps_vector); auto rsp_get_model = 
rsp_get_model_builder.Finish(); fbb->Finish(rsp_get_model); return; diff --git a/mindspore/ccsrc/fl/server/kernel/round/get_model_kernel.h b/mindspore/ccsrc/fl/server/kernel/round/get_model_kernel.h index 15379a8abbe..31039edc5f3 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/get_model_kernel.h +++ b/mindspore/ccsrc/fl/server/kernel/round/get_model_kernel.h @@ -25,6 +25,7 @@ #include "fl/server/executor.h" #include "fl/server/kernel/round/round_kernel.h" #include "fl/server/kernel/round/round_kernel_factory.h" +#include "fl/compression/encode_executor.h" namespace mindspore { namespace fl { @@ -44,7 +45,9 @@ class GetModelKernel : public RoundKernel { void GetModel(const schema::RequestGetModel *get_model_req, const std::shared_ptr &message); void BuildGetModelRsp(const std::shared_ptr &fbb, const schema::ResponseCode retcode, const std::string &reason, const size_t iter, - const std::map &feature_maps, const std::string ×tamp); + const std::map &feature_maps, const std::string ×tamp, + const schema::CompressType &compressType = schema::CompressType_NO_COMPRESS, + const std::map &compress_feature_maps = {}); // The executor is for getting model for getModel request. 
Executor *executor_; diff --git a/mindspore/ccsrc/fl/server/kernel/round/start_fl_job_kernel.cc b/mindspore/ccsrc/fl/server/kernel/round/start_fl_job_kernel.cc index ff4210edb5b..bb5915f09e1 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/start_fl_job_kernel.cc +++ b/mindspore/ccsrc/fl/server/kernel/round/start_fl_job_kernel.cc @@ -126,10 +126,19 @@ bool StartFLJobKernel::Launch(const uint8_t *req_data, size_t len, IncreaseAcceptClientNum(); auto curr_iter_num = LocalMetaStore::GetInstance().curr_iter_num(); auto last_iteration = curr_iter_num - 1; - auto cache = ModelStore::GetInstance().GetModelResponseCache(name_, curr_iter_num, last_iteration); + auto download_compress_types = start_fl_job_req->download_compress_types(); + schema::CompressType compressType = + mindspore::fl::compression::CompressExecutor::GetInstance().GetCompressType(download_compress_types); + std::string compress_type; + if (compressType == schema::CompressType_QUANT) { + compress_type = kQuant; + } else { + compress_type = kNoCompress; + } + auto cache = ModelStore::GetInstance().GetModelResponseCache(name_, curr_iter_num, last_iteration, compress_type); if (cache == nullptr) { - StartFLJob(fbb); - cache = ModelStore::GetInstance().StoreModelResponseCache(name_, curr_iter_num, last_iteration, + StartFLJob(fbb, device_meta, start_fl_job_req); + cache = ModelStore::GetInstance().StoreModelResponseCache(name_, curr_iter_num, last_iteration, compress_type, fbb->GetBufferPointer(), fbb->GetSize()); if (cache == nullptr) { SendResponseMsg(message, fbb->GetBufferPointer(), fbb->GetSize()); @@ -303,22 +312,40 @@ ResultCode StartFLJobKernel::CountForStartFLJob(const std::shared_ptr return ResultCode::kSuccess; } -void StartFLJobKernel::StartFLJob(const std::shared_ptr &fbb) { +void StartFLJobKernel::StartFLJob(const std::shared_ptr &fbb, const DeviceMeta &, + const schema::RequestFLJob *start_fl_job_req) { size_t last_iteration = LocalMetaStore::GetInstance().curr_iter_num() - 1; - auto 
feature_maps = ModelStore::GetInstance().GetModelByIterNum(last_iteration); - if (feature_maps.empty()) { - MS_LOG(WARNING) << "The feature map for startFLJob is empty."; + + std::map feature_maps = {}; + std::map compress_feature_maps = {}; + + // Only download compress weights if client support. + auto download_compress_types = start_fl_job_req->download_compress_types(); + schema::CompressType compressType = + mindspore::fl::compression::CompressExecutor::GetInstance().GetCompressType(download_compress_types); + if (compressType == schema::CompressType_NO_COMPRESS) { + feature_maps = ModelStore::GetInstance().GetModelByIterNum(last_iteration); + if (feature_maps.empty()) { + MS_LOG(WARNING) << "The feature map for startFLJob is empty."; + } + } else { + if (mindspore::fl::compression::CompressExecutor::GetInstance().EnableCompressWeight(compressType)) { + compress_feature_maps = ModelStore::GetInstance().GetCompressModelByIterNum(last_iteration, compressType); + } } + BuildStartFLJobRsp(fbb, schema::ResponseCode_SUCCEED, "success", true, std::to_string(LocalMetaStore::GetInstance().value(kCtxIterationNextRequestTimestamp)), - feature_maps); + feature_maps, compressType, compress_feature_maps); return; } void StartFLJobKernel::BuildStartFLJobRsp(const std::shared_ptr &fbb, const schema::ResponseCode retcode, const std::string &reason, const bool is_selected, const std::string &next_req_time, - std::map feature_maps) { + const std::map &feature_maps, + const schema::CompressType &compressType, + const std::map &compress_feature_maps) { if (fbb == nullptr) { MS_LOG(WARNING) << "Input fbb is nullptr."; return; @@ -350,6 +377,12 @@ void StartFLJobKernel::BuildStartFLJobRsp(const std::shared_ptr &fbb, auto cipher_public_params = schema::CreateCipherPublicParams(*fbb.get(), encrypt_type, pw_params, dp_params, ds_params); #endif + schema::CompressType upload_compress_type; + if (ps::PSContext::instance()->upload_compress_type() == kDiffSparseQuant) { + 
upload_compress_type = schema::CompressType_DIFF_SPARSE_QUANT; + } else { + upload_compress_type = schema::CompressType_NO_COMPRESS; + } schema::FLPlanBuilder fl_plan_builder(*(fbb.get())); fl_plan_builder.add_fl_name(fbs_fl_name); @@ -375,6 +408,33 @@ void StartFLJobKernel::BuildStartFLJobRsp(const std::shared_ptr &fbb, } auto fbs_feature_maps_vector = fbb->CreateVector(fbs_feature_maps); + // construct compress feature maps with fbs + std::vector> fbs_compress_feature_maps; + for (const auto &compress_feature_map : compress_feature_maps) { + if (compressType == schema::CompressType_QUANT) { + if (compress_feature_map.first.find(kMinVal) != string::npos || + compress_feature_map.first.find(kMaxVal) != string::npos) { + continue; + } + auto fbs_compress_weight_fullname = fbb->CreateString(compress_feature_map.first); + auto fbs_compress_weight_data = fbb->CreateVector(reinterpret_cast(compress_feature_map.second->addr), + compress_feature_map.second->size / sizeof(int8_t)); + + const std::string min_val_name = compress_feature_map.first + "." + kMinVal; + const std::string max_val_name = compress_feature_map.first + "." 
+ kMaxVal; + + const AddressPtr min_val_ptr = compress_feature_maps.at(min_val_name); + const AddressPtr max_val_ptr = compress_feature_maps.at(max_val_name); + + float *fbs_min_val_ptr = reinterpret_cast(min_val_ptr->addr); + float *fbs_max_val_ptr = reinterpret_cast(max_val_ptr->addr); + auto fbs_compress_feature_map = schema::CreateCompressFeatureMap( + *(fbb.get()), fbs_compress_weight_fullname, fbs_compress_weight_data, *fbs_min_val_ptr, *fbs_max_val_ptr); + fbs_compress_feature_maps.push_back(fbs_compress_feature_map); + } + } + auto fbs_compress_feature_maps_vector = fbb->CreateVector(fbs_compress_feature_maps); + schema::ResponseFLJobBuilder rsp_fl_job_builder(*(fbb.get())); rsp_fl_job_builder.add_retcode(static_cast(retcode)); rsp_fl_job_builder.add_reason(fbs_reason); @@ -383,6 +443,10 @@ void StartFLJobKernel::BuildStartFLJobRsp(const std::shared_ptr &fbb, rsp_fl_job_builder.add_next_req_time(fbs_next_req_time); rsp_fl_job_builder.add_fl_plan_config(fbs_fl_plan); rsp_fl_job_builder.add_feature_map(fbs_feature_maps_vector); + rsp_fl_job_builder.add_download_compress_type(compressType); + rsp_fl_job_builder.add_compress_feature_map(fbs_compress_feature_maps_vector); + rsp_fl_job_builder.add_upload_compress_type(upload_compress_type); + rsp_fl_job_builder.add_upload_sparse_rate(ps::PSContext::instance()->upload_sparse_rate()); auto rsp_fl_job = rsp_fl_job_builder.Finish(); fbb->Finish(rsp_fl_job); return; diff --git a/mindspore/ccsrc/fl/server/kernel/round/start_fl_job_kernel.h b/mindspore/ccsrc/fl/server/kernel/round/start_fl_job_kernel.h index 1c29e0aed8a..01c90bdfe10 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/start_fl_job_kernel.h +++ b/mindspore/ccsrc/fl/server/kernel/round/start_fl_job_kernel.h @@ -25,6 +25,9 @@ #include "fl/server/executor.h" #include "fl/server/kernel/round/round_kernel.h" #include "fl/server/kernel/round/round_kernel_factory.h" +#include "schema/fl_job_generated.h" +#include "schema/cipher_generated.h" +#include 
"fl/compression/encode_executor.h" namespace mindspore { namespace fl { @@ -56,7 +59,8 @@ class StartFLJobKernel : public RoundKernel { // Distributed count service counts for startFLJob. ResultCode CountForStartFLJob(const std::shared_ptr &fbb, const schema::RequestFLJob *start_fl_job_req); - void StartFLJob(const std::shared_ptr &fbb); + void StartFLJob(const std::shared_ptr &fbb, const DeviceMeta &device_meta, + const schema::RequestFLJob *start_fl_job_req); bool JudgeFLJobCert(const std::shared_ptr &fbb, const schema::RequestFLJob *start_fl_job_req); @@ -65,7 +69,9 @@ class StartFLJobKernel : public RoundKernel { // Build response for startFLJob round no matter success or failure. void BuildStartFLJobRsp(const std::shared_ptr &fbb, const schema::ResponseCode retcode, const std::string &reason, const bool is_selected, const std::string &next_req_time, - std::map feature_maps = {}); + const std::map &feature_maps = {}, + const schema::CompressType &compressType = schema::CompressType_NO_COMPRESS, + const std::map &compress_feature_maps = {}); // The executor is for getting the initial model for startFLJob request. 
Executor *executor_; diff --git a/mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.cc b/mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.cc index 9c9b8700baa..631c89cedf4 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.cc +++ b/mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.cc @@ -201,23 +201,27 @@ ResultCode UpdateModelKernel::VerifyUpdateModel(const schema::RequestUpdateModel } std::unordered_map feature_map; - auto upload_feature_map = update_model_req->feature_map(); - MS_ERROR_IF_NULL_W_RET_VAL(upload_feature_map, ResultCode::kFail); - for (uint32_t i = 0; i < upload_feature_map->size(); i++) { - const auto &item = upload_feature_map->Get(i); - MS_ERROR_IF_NULL_W_RET_VAL(item, ResultCode::kFail); - MS_ERROR_IF_NULL_W_RET_VAL(item->weight_fullname(), ResultCode::kFail); - MS_ERROR_IF_NULL_W_RET_VAL(item->data(), ResultCode::kFail); + if (ps::PSContext::instance()->upload_compress_type() != kDiffSparseQuant) { + auto upload_feature_map = update_model_req->feature_map(); + MS_ERROR_IF_NULL_W_RET_VAL(upload_feature_map, ResultCode::kFail); + for (uint32_t i = 0; i < upload_feature_map->size(); i++) { + const auto &item = upload_feature_map->Get(i); + MS_ERROR_IF_NULL_W_RET_VAL(item, ResultCode::kFail); + MS_ERROR_IF_NULL_W_RET_VAL(item->weight_fullname(), ResultCode::kFail); + MS_ERROR_IF_NULL_W_RET_VAL(item->data(), ResultCode::kFail); - std::string weight_full_name = item->weight_fullname()->str(); - size_t weight_size = item->data()->size() * sizeof(float); - feature_map[weight_full_name] = weight_size; + std::string weight_full_name = item->weight_fullname()->str(); + size_t weight_size = item->data()->size() * sizeof(float); + feature_map[weight_full_name] = weight_size; + } } bool verifyFeatureMapIsSuccess; if (ps::PSContext::instance()->encrypt_type() == ps::kDSEncryptType && update_model_req->sign() != 0) { MS_ERROR_IF_NULL_W_RET_VAL(update_model_req->index_array(), ResultCode::kFail); 
verifyFeatureMapIsSuccess = VerifySignDSFeatureMap(feature_map, update_model_req); + } else if (ps::PSContext::instance()->upload_compress_type() == kDiffSparseQuant) { + verifyFeatureMapIsSuccess = VerifyUploadCompressFeatureMap(update_model_req); } else { verifyFeatureMapIsSuccess = LocalMetaStore::GetInstance().verifyAggregationFeatureMap(feature_map); } @@ -280,6 +284,45 @@ bool UpdateModelKernel::VerifySignDSFeatureMap(const std::unordered_mapupload_sparse_rate(); + if (upload_sparse_rate != ps::PSContext::instance()->upload_sparse_rate()) { + MS_LOG(WARNING) << "The upload_sparse_rate must be equal to the setting in context."; + return false; + } + auto fbs_name_vec = update_model_req->name_vec(); + if (fbs_name_vec == nullptr) { + MS_LOG(WARNING) << "The name_vec is null."; + return false; + } + if (fbs_name_vec->size() == 0) { + MS_LOG(WARNING) << "The size of name_vec must be larger than 0."; + return false; + } + if (fbs_name_vec->size() > aggregation_feature_map_.size()) { + MS_LOG(WARNING) << "The size of name_vec must be smaller than model in server."; + return false; + } + for (size_t i = 0; i < fbs_name_vec->size(); ++i) { + std::string name = fbs_name_vec->Get(i)->str(); + if (aggregation_feature_map_.count(name) == 0) { + MS_LOG(WARNING) << "The upload name: " << name << " is not in model in server."; + return false; + } + } + auto fbs_compress_feature_map = update_model_req->compress_feature_map(); + if (fbs_compress_feature_map == nullptr) { + MS_LOG(WARNING) << "The upload compress feature map is null."; + return false; + } + if (fbs_compress_feature_map->size() == 0) { + MS_LOG(WARNING) << "The upload compress feature map is empty."; + return false; + } + return true; +} + ResultCode UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *update_model_req, const std::shared_ptr &fbb, const DeviceMeta &device_meta) { MS_ERROR_IF_NULL_W_RET_VAL(update_model_req, ResultCode::kFail); @@ -292,6 +335,8 @@ ResultCode 
UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *upda std::map feature_map; if (ps::PSContext::instance()->encrypt_type() == ps::kDSEncryptType) { feature_map = ParseSignDSFeatureMap(update_model_req, data_size, &weight_map); + } else if (ps::PSContext::instance()->upload_compress_type() == kDiffSparseQuant) { + feature_map = ParseUploadCompressFeatureMap(update_model_req, data_size, &weight_map); } else { feature_map = ParseFeatureMap(update_model_req); } @@ -397,6 +442,89 @@ std::map UpdateModelKernel::ParseSignDSFeatureMap( return feature_map; } +std::map UpdateModelKernel::ParseUploadCompressFeatureMap( + const schema::RequestUpdateModel *update_model_req, size_t data_size, + std::map> *weight_map) { + MS_ERROR_IF_NULL_W_RET_VAL(update_model_req, {}); + std::map feature_map; + schema::CompressType upload_compress_type = update_model_req->upload_compress_type(); + upload_compress_type = + mindspore::fl::compression::DecodeExecutor::GetInstance().GetCompressType(upload_compress_type); + MS_LOG(INFO) << "This schema upload compress type is: " << upload_compress_type; + if (upload_compress_type != schema::CompressType_NO_COMPRESS) { + MS_LOG(INFO) << "This upload compress type is DIFF_SPARSE_QUANT."; + feature_map = DecodeFeatureMap(weight_map, update_model_req, upload_compress_type, data_size); + return feature_map; + } + MS_LOG(INFO) << "This upload compress type is NO_COMPRESS."; + // Some clients upload origin weights. 
+ auto fbs_feature_map = update_model_req->feature_map(); + MS_ERROR_IF_NULL_W_RET_VAL(fbs_feature_map, feature_map); + for (uint32_t i = 0; i < fbs_feature_map->size(); i++) { + std::string weight_full_name = fbs_feature_map->Get(i)->weight_fullname()->str(); + float *weight_data = const_cast(fbs_feature_map->Get(i)->data()->data()); + size_t weight_size = fbs_feature_map->Get(i)->data()->size() * sizeof(float); + UploadData upload_data; + upload_data[kNewWeight].addr = weight_data; + upload_data[kNewWeight].size = weight_size; + feature_map[weight_full_name] = upload_data; + } + return feature_map; +} + +std::map UpdateModelKernel::DecodeFeatureMap( + std::map> *weight_map, const schema::RequestUpdateModel *update_model_req, + schema::CompressType upload_compress_type, size_t data_size) { + std::map feature_map; + + // Get and set decode hyper parameters. + auto seed = update_model_req->iteration(); + MS_LOG(INFO) << "The seed for compression is: " << seed; + auto upload_sparse_rate = update_model_req->upload_sparse_rate(); + MS_LOG(INFO) << "The upload_sparse_rate for compression is: " << upload_sparse_rate; + // Get name vector. + auto fbs_name_vec = update_model_req->name_vec(); + std::vector name_vec; + for (size_t i = 0; i < fbs_name_vec->size(); ++i) { + name_vec.emplace_back(fbs_name_vec->Get(i)->str()); + } + + // Parameter process for decode. 
+ auto fbs_compress_feature_map = update_model_req->compress_feature_map(); + std::vector compress_feature_maps; + for (size_t i = 0; i < fbs_compress_feature_map->size(); ++i) { + mindspore::fl::compression::CompressFeatureMap compress_feature_map; + int8_t *compress_weight_data = const_cast(fbs_compress_feature_map->Get(i)->compress_data()->data()); + size_t compress_weight_size = fbs_compress_feature_map->Get(i)->compress_data()->size(); + MS_LOG(INFO) << "The compress weight size: " << compress_weight_size; + for (size_t j = 0; j < compress_weight_size; ++j) { + compress_feature_map.compress_data.emplace_back(compress_weight_data[j]); + } + compress_feature_map.min_val = fbs_compress_feature_map->Get(i)->min_val(); + compress_feature_map.max_val = fbs_compress_feature_map->Get(i)->max_val(); + MS_LOG(INFO) << "Min value: " << compress_feature_map.min_val; + MS_LOG(INFO) << "Max value: " << compress_feature_map.max_val; + compress_feature_maps.emplace_back(compress_feature_map); + } + + // Decode. 
+ bool status = mindspore::fl::compression::DecodeExecutor::GetInstance().Decode( + weight_map, compress_feature_maps, upload_compress_type, upload_sparse_rate, seed, name_vec, data_size); + if (status) { + for (size_t i = 0; i < name_vec.size(); ++i) { + std::string weight_full_name = name_vec[i]; + size_t weight_size = (*weight_map)[weight_full_name].size() * sizeof(float); + UploadData upload_data; + upload_data[kNewWeight].addr = (*weight_map)[weight_full_name].data(); + upload_data[kNewWeight].size = weight_size; + feature_map[weight_full_name] = upload_data; + } + return feature_map; + } + MS_LOG(WARNING) << "Decode failed!"; + return feature_map; +} + ResultCode UpdateModelKernel::CountForAggregation(const std::string &req_fl_id) { std::string count_reason = ""; if (!DistributedCountService::GetInstance().Count(kCountForAggregation, req_fl_id, &count_reason)) { diff --git a/mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.h b/mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.h index df979265190..0be96a840da 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.h +++ b/mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.h @@ -30,6 +30,9 @@ #ifdef ENABLE_ARMOUR #include "fl/armour/cipher/cipher_meta_storage.h" #endif +#include "fl/compression/decode_executor.h" +#include "schema/fl_job_generated.h" +#include "schema/cipher_generated.h" namespace mindspore { namespace fl { @@ -64,8 +67,12 @@ class UpdateModelKernel : public RoundKernel { std::map ParseSignDSFeatureMap(const schema::RequestUpdateModel *update_model_req, size_t data_size, std::map> *weight_map); + std::map ParseUploadCompressFeatureMap( + const schema::RequestUpdateModel *update_model_req, size_t data_size, + std::map> *weight_map); bool VerifySignDSFeatureMap(const std::unordered_map &model, const schema::RequestUpdateModel *update_model_req); + bool VerifyUploadCompressFeatureMap(const schema::RequestUpdateModel *update_model_req); ResultCode 
CountForUpdateModel(const std::shared_ptr &fbb, const schema::RequestUpdateModel *update_model_req); sigVerifyResult VerifySignature(const schema::RequestUpdateModel *update_model_req); @@ -78,6 +85,11 @@ class UpdateModelKernel : public RoundKernel { // The time window of one iteration. size_t iteration_time_window_{0}; + + // Decode functions of compression. + std::map DecodeFeatureMap(std::map> *weight_map, + const schema::RequestUpdateModel *update_model_req, + schema::CompressType upload_compress_type, size_t data_size); }; } // namespace kernel } // namespace server diff --git a/mindspore/ccsrc/fl/server/memory_register.cc b/mindspore/ccsrc/fl/server/memory_register.cc index 4512211005e..0360eb222a9 100644 --- a/mindspore/ccsrc/fl/server/memory_register.cc +++ b/mindspore/ccsrc/fl/server/memory_register.cc @@ -44,6 +44,11 @@ void MemoryRegister::StoreCharArray(std::unique_ptr *array) { MS_ERROR_IF_NULL_WO_RET_VAL(array); char_arrays_.push_back(std::move(*array)); } + +void MemoryRegister::StoreFloat32(std::unique_ptr *param) { + MS_ERROR_IF_NULL_WO_RET_VAL(param); + float_params_.push_back(std::move(*param)); +} } // namespace server } // namespace fl } // namespace mindspore diff --git a/mindspore/ccsrc/fl/server/memory_register.h b/mindspore/ccsrc/fl/server/memory_register.h index b6e7ce8c094..92dff14699f 100644 --- a/mindspore/ccsrc/fl/server/memory_register.h +++ b/mindspore/ccsrc/fl/server/memory_register.h @@ -24,6 +24,7 @@ #include #include #include "fl/server/common.h" +#include "fl/compression/encode_executor.h" namespace mindspore { namespace fl { @@ -70,6 +71,25 @@ class MemoryRegister { return; } + template + void RegisterParameter(const std::string &name, std::unique_ptr *param, size_t size) { + MS_EXCEPTION_IF_NULL(param); + void *data = param->get(); + AddressPtr addressPtr = std::make_shared
(); + addressPtr->addr = data; + addressPtr->size = size; + if (typeid(T) == typeid(float)) { + auto float_param = CastUniqueParamPtr(param); + StoreFloat32(&float_param); + } else { + MS_LOG(ERROR) << "MemoryRegister does not support type " << typeid(T).name(); + return; + } + + RegisterAddressPtr(name, addressPtr); + return; + } + private: std::map addresses_; std::vector> float_arrays_; @@ -86,6 +106,15 @@ class MemoryRegister { std::unique_ptr CastUniquePtr(std::unique_ptr *array) { return std::unique_ptr{reinterpret_cast(array->release())}; } + + std::vector> float_params_; + + void StoreFloat32(std::unique_ptr *array); + + template + std::unique_ptr CastUniqueParamPtr(std::unique_ptr *param) { + return std::unique_ptr{reinterpret_cast(param->release())}; + } }; } // namespace server } // namespace fl diff --git a/mindspore/ccsrc/fl/server/model_store.cc b/mindspore/ccsrc/fl/server/model_store.cc index 39c8f6c9ed7..3d112491c3a 100644 --- a/mindspore/ccsrc/fl/server/model_store.cc +++ b/mindspore/ccsrc/fl/server/model_store.cc @@ -19,6 +19,7 @@ #include #include #include "fl/server/executor.h" +#include "pipeline/jit/parse/parse.h" #include "include/common/utils/python_adapter.h" namespace mindspore { @@ -33,6 +34,10 @@ void ModelStore::Initialize(uint32_t rank_id, uint32_t max_count) { max_model_count_ = max_count; initial_model_ = AssignNewModelMemory(); iteration_to_model_[kInitIterationNum] = initial_model_; + std::map model = Executor::GetInstance().GetModel(); + for (const auto &item : mindspore::fl::compression::kCompressTypeMap) { + iteration_to_compress_model_[kInitIterationNum][item.first] = AssignNewCompressModelMemory(item.first, model); + } model_size_ = ComputeModelSize(); MS_LOG(INFO) << "Model store checkpoint dir is: " << ps::PSContext::instance()->checkpoint_dir(); } @@ -101,6 +106,24 @@ std::map ModelStore::GetModelByIterNum(size_t iteration return model; } +std::map ModelStore::GetCompressModelByIterNum(size_t iteration, + 
schema::CompressType compressType) { + std::unique_lock lock(model_mtx_); + std::map compressModel = {}; + if (iteration_to_compress_model_.count(iteration) == 0) { + MS_LOG(ERROR) << "Compress Model for iteration " << iteration << " is not stored."; + return compressModel; + } + std::map> compress_model_map = + iteration_to_compress_model_[iteration]; + if (compress_model_map.count(compressType) == 0) { + MS_LOG(ERROR) << "Compress Model for compress type " << compressType << " is not stored."; + return compressModel; + } + compressModel = iteration_to_compress_model_[iteration][compressType]->addresses(); + return compressModel; +} + void ModelStore::Reset() { std::unique_lock lock(model_mtx_); initial_model_ = iteration_to_model_.rbegin()->second; @@ -114,6 +137,11 @@ const std::map> &ModelStore::iteration_t return iteration_to_model_; } +const std::map &ModelStore::iteration_to_compress_model() { + std::unique_lock lock(model_mtx_); + return iteration_to_compress_model_; +} + size_t ModelStore::model_size() const { return model_size_; } std::shared_ptr ModelStore::AssignNewModelMemory() { @@ -146,6 +174,86 @@ std::shared_ptr ModelStore::AssignNewModelMemory() { return memory_register; } +std::shared_ptr ModelStore::AssignNewCompressModelMemory( + schema::CompressType compressType, const std::map &model) { + if (model.empty()) { + MS_LOG(EXCEPTION) << "Model feature map is empty."; + return nullptr; + } + std::map> feature_maps; + for (auto &feature_map : model) { + auto weight_fullname = feature_map.first; + auto weight_data = reinterpret_cast(feature_map.second->addr); + std::vector weight_data_vector{weight_data, weight_data + feature_map.second->size / sizeof(float)}; + feature_maps[weight_fullname] = weight_data_vector; + } + + std::map compressWeights; + bool status = mindspore::fl::compression::CompressExecutor::GetInstance().construct_compress_weight( + &compressWeights, feature_maps, compressType); + if (!status) { + MS_LOG(ERROR) << "Encode failed!"; + 
return nullptr; + } + + // Assign new memory for the compress model. + std::shared_ptr memory_register = std::make_shared(); + MS_ERROR_IF_NULL_W_RET_VAL(memory_register, nullptr); + MS_LOG(INFO) << "Register compressWeight for compressType: " << schema::EnumNameCompressType(compressType); + + for (const auto &compressWeight : compressWeights) { + if (compressType == schema::CompressType_QUANT) { + std::string compress_weight_name = compressWeight.first; + std::string min_val_name = compress_weight_name + "." + kMinVal; + std::string max_val_name = compress_weight_name + "." + kMaxVal; + size_t compress_weight_size = compressWeight.second.compress_data_len * sizeof(int8_t); + auto compress_weight_data = std::make_unique(compress_weight_size); + auto src_data_size = compress_weight_size; + auto dst_data_size = compress_weight_size; + int ret = + memcpy_s(compress_weight_data.get(), dst_data_size, compressWeight.second.compress_data.data(), src_data_size); + if (ret != 0) { + MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; + return nullptr; + } + memory_register->RegisterArray(compress_weight_name, &compress_weight_data, compress_weight_size); + size_t float_size = 1; + auto min_val_ptr = std::make_unique(compressWeight.second.min_val); + auto max_val_ptr = std::make_unique(compressWeight.second.max_val); + + memory_register->RegisterParameter(min_val_name, &min_val_ptr, float_size); + memory_register->RegisterParameter(max_val_name, &max_val_ptr, float_size); + } + } + return memory_register; +} + +void ModelStore::StoreCompressModelByIterNum(size_t iteration, const std::map &new_model) { + std::unique_lock lock(model_mtx_); + if (iteration_to_compress_model_.count(iteration) != 0) { + MS_LOG(WARNING) << "Compress Model for iteration " << iteration << " is already stored"; + return; + } + if (new_model.empty()) { + MS_LOG(ERROR) << "Compress Model feature map is empty."; + return; + } + + iteration_to_compress_model_[iteration] = {}; + if 
(iteration_to_compress_model_.size() >= max_model_count_) { + auto compress_model_map = iteration_to_compress_model_.begin()->second; + compress_model_map.clear(); + (void)iteration_to_compress_model_.erase(iteration_to_compress_model_.begin()); + } + + for (const auto &item : mindspore::fl::compression::kCompressTypeMap) { + auto memory_register = AssignNewCompressModelMemory(item.first, new_model); + MS_ERROR_IF_NULL_WO_RET_VAL(memory_register); + iteration_to_compress_model_[iteration][item.first] = memory_register; + } + return; +} + size_t ModelStore::ComputeModelSize() { std::unique_lock lock(model_mtx_); if (iteration_to_model_.empty()) { @@ -179,13 +287,15 @@ void ModelStore::RelModelResponseCache(const void *data, size_t datalen, void *e std::shared_ptr> ModelStore::GetModelResponseCache(const string &round_name, size_t cur_iteration_num, - size_t model_iteration_num) { + size_t model_iteration_num, + const std::string &compress_type) { std::unique_lock lock(model_response_cache_lock_); - auto it = std::find_if(model_response_cache_.begin(), model_response_cache_.end(), - [&round_name, cur_iteration_num, model_iteration_num](const HttpResponseModelCache &item) { - return item.round_name == round_name && item.cur_iteration_num == cur_iteration_num && - item.model_iteration_num == model_iteration_num; - }); + auto it = std::find_if( + model_response_cache_.begin(), model_response_cache_.end(), + [&round_name, cur_iteration_num, model_iteration_num, &compress_type](const HttpResponseModelCache &item) { + return item.round_name == round_name && item.cur_iteration_num == cur_iteration_num && + item.model_iteration_num == model_iteration_num && item.compress_type == compress_type; + }); if (it == model_response_cache_.end()) { return nullptr; } @@ -196,14 +306,16 @@ std::shared_ptr> ModelStore::GetModelResponseCache(const st std::shared_ptr> ModelStore::StoreModelResponseCache(const string &round_name, size_t cur_iteration_num, - size_t model_iteration_num, 
const void *data, - size_t datalen) { + size_t model_iteration_num, + const std::string &compress_type, + const void *data, size_t datalen) { std::unique_lock lock(model_response_cache_lock_); - auto it = std::find_if(model_response_cache_.begin(), model_response_cache_.end(), - [&round_name, cur_iteration_num, model_iteration_num](const HttpResponseModelCache &item) { - return item.round_name == round_name && item.cur_iteration_num == cur_iteration_num && - item.model_iteration_num == model_iteration_num; - }); + auto it = std::find_if( + model_response_cache_.begin(), model_response_cache_.end(), + [&round_name, cur_iteration_num, model_iteration_num, &compress_type](const HttpResponseModelCache &item) { + return item.round_name == round_name && item.cur_iteration_num == cur_iteration_num && + item.model_iteration_num == model_iteration_num && item.compress_type == compress_type; + }); if (it != model_response_cache_.end()) { it->reference_count += 1; total_add_reference_count += 1; @@ -223,6 +335,7 @@ std::shared_ptr> ModelStore::StoreModelResponseCache(const item.round_name = round_name; item.cur_iteration_num = cur_iteration_num; item.model_iteration_num = model_iteration_num; + item.compress_type = compress_type; item.cache = cache; item.reference_count = 1; total_add_reference_count += 1; diff --git a/mindspore/ccsrc/fl/server/model_store.h b/mindspore/ccsrc/fl/server/model_store.h index 9b890e4fa38..c8c4da7c01b 100644 --- a/mindspore/ccsrc/fl/server/model_store.h +++ b/mindspore/ccsrc/fl/server/model_store.h @@ -25,6 +25,7 @@ #include "fl/server/common.h" #include "fl/server/memory_register.h" #include "fl/server/executor.h" +#include "fl/compression/encode_executor.h" #include "fl/server/local_meta_store.h" namespace mindspore { @@ -36,6 +37,9 @@ constexpr size_t kInitIterationNum = 0; // The initial iteration number after ModelStore is reset. constexpr size_t kResetInitialIterNum = 1; +// The compress type map. 
+using CompressTypeMap = std::map>; + // Server framework use ModelStore to store and query models. // ModelStore stores multiple models because worker could get models of the previous iterations. class ModelStore { @@ -64,15 +68,25 @@ class ModelStore { // Returns the model size, which could be calculated at the initializing phase. size_t model_size() const; + // Get compress model of the given iteration. + std::map GetCompressModelByIterNum(size_t iteration, schema::CompressType compressType); + + const std::map>> + &iteration_to_compress_model(); + + void StoreCompressModelByIterNum(size_t iteration, const std::map &new_model); + static void RelModelResponseCache(const void *data, size_t datalen, void *extra); std::shared_ptr> GetModelResponseCache(const std::string &round_name, size_t cur_iteration_num, - size_t model_iteration_num); + size_t model_iteration_num, + const std::string &compress_type); std::shared_ptr> StoreModelResponseCache(const std::string &round_name, size_t cur_iteration_num, - size_t model_iteration_num, const void *data, + size_t model_iteration_num, + const std::string &compress_type, const void *data, size_t datalen); private: - ModelStore() : max_model_count_(0), model_size_(0), iteration_to_model_({}) {} + ModelStore() : max_model_count_(0), model_size_(0), iteration_to_model_({}), iteration_to_compress_model_({}) {} ~ModelStore() = default; ModelStore(const ModelStore &) = delete; ModelStore &operator=(const ModelStore &) = delete; @@ -83,6 +97,9 @@ class ModelStore { // model_size_. std::shared_ptr AssignNewModelMemory(); + std::shared_ptr AssignNewCompressModelMemory(schema::CompressType compressType, + const std::map &model); + // Calculate the model size. This method should be called after iteration_to_model_ is initialized. size_t ComputeModelSize(); @@ -95,12 +112,17 @@ class ModelStore { // The number of all models stored is max_model_count_. 
std::mutex model_mtx_; std::map> iteration_to_model_; + + // iteration -> (compress type -> compress model) + std::map>> iteration_to_compress_model_; + uint32_t rank_id_; struct HttpResponseModelCache { std::string round_name; // startFlJob, getModel size_t cur_iteration_num = 0; size_t model_iteration_num = 0; + std::string compress_type = kNoCompress; size_t reference_count = 0; std::shared_ptr> cache = nullptr; }; diff --git a/mindspore/ccsrc/pipeline/jit/init.cc b/mindspore/ccsrc/pipeline/jit/init.cc index ab3fa3e875a..0833728d98a 100644 --- a/mindspore/ccsrc/pipeline/jit/init.cc +++ b/mindspore/ccsrc/pipeline/jit/init.cc @@ -507,6 +507,12 @@ PYBIND11_MODULE(_c_expression, m) { .def("set_global_iteration_time_window", &PSContext::set_global_iteration_time_window, "Set global iteration time window.") .def("global_iteration_time_window", &PSContext::global_iteration_time_window, "Get global iteration time window.") + .def("set_upload_compress_type", &PSContext::set_upload_compress_type, "Set upload compress type.") + .def("upload_compress_type", &PSContext::upload_compress_type, "Get upload compress type.") + .def("set_upload_sparse_rate", &PSContext::set_upload_sparse_rate, "Set upload sparse rate.") + .def("upload_sparse_rate", &PSContext::upload_sparse_rate, "Get upload sparse rate.") + .def("set_download_compress_type", &PSContext::set_download_compress_type, "Set download compress type.") + .def("download_compress_type", &PSContext::download_compress_type, "Get download compress type.") .def("set_checkpoint_dir", &PSContext::set_checkpoint_dir, "Set server checkpoint directory.") .def("checkpoint_dir", &PSContext::checkpoint_dir, "Server checkpoint directory."); (void)m.def("_encrypt", &mindspore::pipeline::PyEncrypt, "Encrypt the data."); diff --git a/mindspore/ccsrc/ps/ps_context.cc b/mindspore/ccsrc/ps/ps_context.cc index 089f991445d..8f8fc1d66f6 100644 --- a/mindspore/ccsrc/ps/ps_context.cc +++ b/mindspore/ccsrc/ps/ps_context.cc @@ -550,6 +550,19 @@ 
void PSContext::set_global_iteration_time_window(const uint64_t &global_iteratio uint64_t PSContext::global_iteration_time_window() const { return global_iteration_time_window_; } +void PSContext::set_upload_compress_type(const std::string &upload_compress_type) { + upload_compress_type_ = upload_compress_type; +} +std::string PSContext::upload_compress_type() const { return upload_compress_type_; } + +void PSContext::set_upload_sparse_rate(float upload_sparse_rate) { upload_sparse_rate_ = upload_sparse_rate; } +float PSContext::upload_sparse_rate() const { return upload_sparse_rate_; } + +void PSContext::set_download_compress_type(const std::string &download_compress_type) { + download_compress_type_ = download_compress_type; +} +std::string PSContext::download_compress_type() const { return download_compress_type_; } + std::string PSContext::checkpoint_dir() const { return checkpoint_dir_; } void PSContext::set_checkpoint_dir(const std::string &checkpoint_dir) { checkpoint_dir_ = checkpoint_dir; } diff --git a/mindspore/ccsrc/ps/ps_context.h b/mindspore/ccsrc/ps/ps_context.h index 4746953ed0f..7027d73fd7a 100644 --- a/mindspore/ccsrc/ps/ps_context.h +++ b/mindspore/ccsrc/ps/ps_context.h @@ -40,6 +40,7 @@ constexpr char kPWEncryptType[] = "PW_ENCRYPT"; constexpr char kStablePWEncryptType[] = "STABLE_PW_ENCRYPT"; constexpr char kNotEncryptType[] = "NOT_ENCRYPT"; constexpr char kDSEncryptType[] = "SIGNDS"; +constexpr char kNoCompressType[] = "NO_COMPRESS"; // Use binary data to represent federated learning server's context so that we can judge which round resets the // iteration. 
From right to left, each bit stands for: @@ -230,6 +231,15 @@ class PSContext { void set_global_iteration_time_window(const uint64_t &global_iteration_time_window); uint64_t global_iteration_time_window() const; + void set_upload_compress_type(const std::string &upload_compress_type); + std::string upload_compress_type() const; + + void set_upload_sparse_rate(float upload_sparse_rate); + float upload_sparse_rate() const; + + void set_download_compress_type(const std::string &download_compress_type); + std::string download_compress_type() const; + std::string checkpoint_dir() const; void set_checkpoint_dir(const std::string &checkpoint_dir); @@ -286,6 +296,9 @@ class PSContext { server_password_(""), http_url_prefix_(""), global_iteration_time_window_(3600000), + upload_compress_type_(kNoCompressType), + upload_sparse_rate_(0.4f), + download_compress_type_(kNoCompressType), checkpoint_dir_("") {} bool ps_enabled_; bool is_worker_; @@ -419,6 +432,13 @@ class PSContext { // The time window of startFLJob round in millisecond. uint64_t global_iteration_time_window_; + + // Hyper parameters for upload compression. + std::string upload_compress_type_; + float upload_sparse_rate_; + // Hyper parameters for download compression. + std::string download_compress_type_; + // directory of server checkpoint std::string checkpoint_dir_; }; diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/FLLiteClient.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/FLLiteClient.java index 0d5741436e4..9e04a92523a 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/FLLiteClient.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/FLLiteClient.java @@ -105,6 +105,16 @@ public class FLLiteClient { batchSize = flPlan.miniBatch(); String serverMod = flPlan.serverMode(); localFLParameter.setServerMod(serverMod); + // Get and set hyper parameters for compression. 
+ byte uploadCompressType = flJob.uploadCompressType(); + LOGGER.info(Common.addTag("[startFLJob] [compression] uploadCompressType: " + uploadCompressType)); + localFLParameter.setUploadCompressType(uploadCompressType); + float uploadSparseRate = flJob.uploadSparseRate(); + LOGGER.info(Common.addTag("[startFLJob] [compression] uploadSparseRate: " + uploadSparseRate)); + localFLParameter.setUploadSparseRatio(uploadSparseRate); + int seed = flJob.iteration(); + LOGGER.info(Common.addTag("[startFLJob] [compression] seed: " + seed)); + localFLParameter.setSeed(seed); if (Common.checkFLName(flParameter.getFlName())) { deprecatedSetBatchSize(batchSize); } else { @@ -446,7 +456,7 @@ public class FLLiteClient { return status; } - private Map getFeatureMap() { + public Map getFeatureMap() { Map featureMap = new HashMap<>(); if (Common.checkFLName(flParameter.getFlName())) { featureMap = deprecatedGetFeatureMap(); @@ -530,8 +540,7 @@ public class FLLiteClient { localFLParameter.getEncryptLevel().toString() + "> : " + curStatus)); return curStatus; case DP_ENCRYPT: - // get the feature map before train - oldFeatureMap = getFeatureMap(); + oldFeatureMap = localFLParameter.getOldFeatureMap(); curStatus = secureProtocol.setDPParameter(iteration, dpEps, dpDelta, dpNormClipAdapt, oldFeatureMap); retCode = ResponseCode.SUCCEED; if (curStatus != FLClientStatus.SUCCESS) { @@ -542,8 +551,7 @@ public class FLLiteClient { LOGGER.info(Common.addTag("[Encrypt] set parameters for DP_ENCRYPT!")); return FLClientStatus.SUCCESS; case SIGNDS: - // get the feature map before train - oldFeatureMap = getFeatureMap(); + oldFeatureMap = localFLParameter.getOldFeatureMap(); curStatus = secureProtocol.setDSParameter(signK, signEps, signThrRatio, signGlobalLr, signDimOut, oldFeatureMap); retCode = ResponseCode.SUCCEED; if (curStatus != FLClientStatus.SUCCESS) { diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/FLParameter.java 
b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/FLParameter.java index 9cf3b5a8d16..9ef777c00b9 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/FLParameter.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/FLParameter.java @@ -18,7 +18,9 @@ package com.mindspore.flclient; import static com.mindspore.flclient.LocalFLParameter.ALBERT; +import com.mindspore.flclient.compression.CompressMode; import com.mindspore.flclient.model.RunType; +import mindspore.schema.CompressType; import java.util.ArrayList; import java.util.Arrays; @@ -603,6 +605,16 @@ public class FLParameter { this.batchSize = batchSize; } + public byte[] getDownloadCompressTypes() { + byte[] downloadCompressTypes = new byte[CompressMode.COMPRESS_TYPE_MAP.size()]; + int index = 0; + for (byte downloadCompressType : CompressMode.COMPRESS_TYPE_MAP.keySet()) { + downloadCompressTypes[index] = downloadCompressType; + index += 1; + } + return downloadCompressTypes; + } + public int[][] getInputShape() { return inputShape; } diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/GetModel.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/GetModel.java index 8af179b122c..1cfd426d8da 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/GetModel.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/GetModel.java @@ -18,6 +18,7 @@ package com.mindspore.flclient; import com.google.flatbuffers.FlatBufferBuilder; +import com.mindspore.flclient.compression.DecodeExecutor; import com.mindspore.flclient.model.AlInferBert; import com.mindspore.flclient.model.AlTrainBert; import com.mindspore.flclient.model.Client; @@ -27,11 +28,9 @@ import com.mindspore.flclient.model.SessionUtil; import com.mindspore.flclient.model.Status; import com.mindspore.flclient.model.TrainLenet; -import mindspore.schema.FeatureMap; 
-import mindspore.schema.RequestGetModel; -import mindspore.schema.ResponseCode; -import mindspore.schema.ResponseGetModel; +import mindspore.schema.*; +import java.util.List; import java.util.ArrayList; import java.util.Date; import java.util.logging.Logger; @@ -94,7 +93,8 @@ public class GetModel { throw new IllegalArgumentException(); } RequestGetModelBuilder builder = new RequestGetModelBuilder(); - return builder.iteration(iteration).flName(name).time().build(); + return builder.iteration(iteration).flName(name).time() + .downloadCompressTypesBuilder(flParameter.getDownloadCompressTypes()).build(); } private FLClientStatus deprecatedParseResponseAlbert(ResponseGetModel responseDataBuf) { @@ -226,11 +226,29 @@ public class GetModel { return status; } + private List parseFeatureMapList(ResponseGetModel responseDataBuf) { + List featureMaps; + byte compressType = responseDataBuf.downloadCompressType(); + if (responseDataBuf.downloadCompressType() == mindspore.schema.CompressType.NO_COMPRESS) { + featureMaps = new ArrayList<>(); + for (int i = 0; i < responseDataBuf.featureMapLength(); i++) { + featureMaps.add(responseDataBuf.featureMap(i)); + } + } else { + List compressFeatureMapList = new ArrayList<>(); + for (int i = 0; i < responseDataBuf.compressFeatureMapLength(); i++) { + compressFeatureMapList.add(responseDataBuf.compressFeatureMap(i)); + } + featureMaps = DecodeExecutor.getInstance().deCompressWeight(compressType, compressFeatureMapList); + } + return featureMaps; + } + private FLClientStatus parseResponseFeatures(ResponseGetModel responseDataBuf) { FLClientStatus status; Client client = ClientManager.getClient(flParameter.getFlName()); - int fmCount = responseDataBuf.featureMapLength(); - if (fmCount <= 0) { + List featureMapList = parseFeatureMapList(responseDataBuf); + if (featureMapList.size() <= 0) { LOGGER.severe(Common.addTag("[getModel] the feature size get from server is zero")); retCode = ResponseCode.SystemError; return FLClientStatus.FAILED; 
@@ -239,8 +257,8 @@ public class GetModel { LOGGER.info(Common.addTag("[getModel] parseResponseFeatures by " + localFLParameter.getServerMod())); ArrayList trainFeatureMaps = new ArrayList(); ArrayList inferFeatureMaps = new ArrayList(); - for (int i = 0; i < fmCount; i++) { - FeatureMap feature = responseDataBuf.featureMap(i); + for (int i = 0; i < featureMapList.size(); i++) { + FeatureMap feature = featureMapList.get(i); if (feature == null) { LOGGER.severe(Common.addTag("[getModel] the feature returned from server is null")); retCode = ResponseCode.SystemError; @@ -289,8 +307,8 @@ public class GetModel { } else if (localFLParameter.getServerMod().equals(ServerMod.FEDERATED_LEARNING.toString())) { LOGGER.info(Common.addTag("[getModel] parseResponseFeatures by " + localFLParameter.getServerMod())); ArrayList featureMaps = new ArrayList(); - for (int i = 0; i < fmCount; i++) { - FeatureMap feature = responseDataBuf.featureMap(i); + for (int i = 0; i < featureMapList.size(); i++) { + FeatureMap feature = featureMapList.get(i); if (feature == null) { LOGGER.severe(Common.addTag("[getModel] the feature returned from server is null")); retCode = ResponseCode.SystemError; @@ -365,6 +383,7 @@ public class GetModel { private int nameOffset = 0; private int iteration = 0; private int timeStampOffset = 0; + private int downloadCompressTypesOffset = 0; public RequestGetModelBuilder() { builder = new FlatBufferBuilder(); @@ -392,11 +411,23 @@ public class GetModel { return this; } + private RequestGetModelBuilder downloadCompressTypesBuilder(byte[] downloadCompressTypes) { + if (downloadCompressTypes == null || downloadCompressTypes.length == 0) { + LOGGER.severe(Common.addTag("[GetModel] the parameter of is null or empty," + + " please check!")); + throw new IllegalArgumentException(); + } + this.downloadCompressTypesOffset = RequestGetModel.createDownloadCompressTypesVector(builder, + downloadCompressTypes); + return this; + } + private byte[] build() { 
RequestGetModel.startRequestGetModel(builder); RequestGetModel.addFlName(builder, nameOffset); RequestGetModel.addIteration(builder, iteration); RequestGetModel.addTimestamp(builder, timeStampOffset); + RequestGetModel.addDownloadCompressTypes(builder, downloadCompressTypesOffset); int root = RequestGetModel.endRequestGetModel(builder); builder.finish(root); return builder.sizedByteArray(); diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/LocalFLParameter.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/LocalFLParameter.java index 624c428178d..835526ee049 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/LocalFLParameter.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/LocalFLParameter.java @@ -22,6 +22,7 @@ import org.bouncycastle.math.ec.rfc7748.X25519; import java.util.ArrayList; import java.util.List; +import java.util.Map; import java.util.logging.Logger; /** @@ -83,6 +84,10 @@ public class LocalFLParameter { private MSConfig msConfig = new MSConfig(); private boolean useSSL = true; private float lr = 0.1f; + private Map oldFeatureMap; + private byte uploadCompressType = 0; + private int seed = 0; + private float uploadSparseRatio = 0.08f; private LocalFLParameter() { @@ -250,4 +255,36 @@ public class LocalFLParameter { public void setLr(float lr) { this.lr = lr; } + + public Map getOldFeatureMap() { + return oldFeatureMap; + } + + public void setOldFeatureMap(Map oldFeatureMap) { + this.oldFeatureMap = oldFeatureMap; + } + + public byte getUploadCompressType() { + return uploadCompressType; + } + + public void setUploadCompressType(byte uploadCompressType) { + this.uploadCompressType = uploadCompressType; + } + + public int getSeed() { + return seed; + } + + public void setSeed(int seed) { + this.seed = seed; + } + + public float getUploadSparseRatio() { + return uploadSparseRatio; + } + + public void 
setUploadSparseRatio(float uploadSparseRatio) { + this.uploadSparseRatio = uploadSparseRatio; + } } diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/SecureProtocol.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/SecureProtocol.java index ae72247e63f..7b50e1e729d 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/SecureProtocol.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/SecureProtocol.java @@ -208,35 +208,34 @@ public class SecureProtocol { * @param trainDataSize trainDataSize tne size of train data set. * @return the serialized model weights after adding masks. */ - public int[] pwMaskModel(FlatBufferBuilder builder, int trainDataSize, Map trainedMap) { + public Map> pwMaskModel(FlatBufferBuilder builder, int trainDataSize, Map trainedMap) { + Map> featureMaps = new HashMap<>(); if (featureMask == null || featureMask.length == 0) { LOGGER.severe("[Encrypt] feature mask is null, please check"); - return new int[0]; + return new HashMap<>(); } LOGGER.info(String.format("[Encrypt] feature mask size: %s", featureMask.length)); int featureSize = updateFeatureName.size(); - int[] featuresMap = new int[featureSize]; int maskIndex = 0; for (int i = 0; i < featureSize; i++) { String key = updateFeatureName.get(i); float[] data = trainedMap.get(key); + List featureMap = new ArrayList<>(); LOGGER.info(String.format("[Encrypt] feature name: %s feature size: %s", key, data.length)); for (int j = 0; j < data.length; j++) { float rawData = data[j]; if (maskIndex >= featureMask.length) { LOGGER.severe("[Encrypt] the maskIndex is out of range for array featureMask, please check"); - return new int[0]; + return new HashMap<>(); } float maskData = rawData * trainDataSize + featureMask[maskIndex]; maskIndex += 1; - data[j] = maskData; + featureMap.add(maskData); } - int featureName = builder.createString(key); - int weight = 
FeatureMap.createDataVector(builder, data); - int featureMap = FeatureMap.createFeatureMap(builder, featureName, weight); - featuresMap[i] = featureMap; + featureMaps.put(key, featureMap); } - return featuresMap; + return featureMaps; } /** @@ -365,7 +364,9 @@ public class SecureProtocol { * @param trainDataSize tne size of train data set. * @return the serialized model weights after adding masks. */ - public int[] dpMaskModel(FlatBufferBuilder builder, int trainDataSize, Map trainedMap) { + public Map> dpMaskModel(FlatBufferBuilder builder, int trainDataSize, + Map trainedMap) { + Map> featureMaps = new HashMap<>(); // get feature map Map mapBeforeTrain = modelMap; int featureSize = updateFeatureName.size(); @@ -383,7 +384,7 @@ public class SecureProtocol { float rawData = data[j]; if (j >= dataBeforeTrain.length) { LOGGER.severe("[Encrypt] the index j is out of range for array dataBeforeTrain, please check"); - return new int[0]; + return new HashMap<>(); } float rawDataBeforeTrain = dataBeforeTrain[j]; float updateData = rawData - rawDataBeforeTrain; @@ -393,23 +394,23 @@ public class SecureProtocol { updateL2Norm = Math.sqrt(updateL2Norm); if (updateL2Norm == 0) { LOGGER.severe(Common.addTag("[Encrypt] updateL2Norm is 0, please check")); - return new int[0]; + return new HashMap<>(); } double clipFactor = Math.min(1.0, dpNormClip / updateL2Norm); // clip and add noise - int[] featuresMap = new int[featureSize]; for (int i = 0; i < featureSize; i++) { String key = updateFeatureName.get(i); if (!trainedMap.containsKey(key)) { LOGGER.severe("[Encrypt] the key: " + key + " is not in map, please check!"); - return new int[0]; + return new HashMap<>(); } float[] data = trainedMap.get(key); float[] data2 = new float[data.length]; + List featureMap = new ArrayList<>(); if (!mapBeforeTrain.containsKey(key)) { LOGGER.severe("[Encrypt] the key: " + key + " is not in mapBeforeTrain, please check!"); - return new int[0]; + return new HashMap<>(); } float[] dataBeforeTrain = 
mapBeforeTrain.get(key); @@ -419,7 +420,7 @@ public class SecureProtocol { float rawData = data[j]; if (j >= dataBeforeTrain.length) { LOGGER.severe("[Encrypt] the index j is out of range for array dataBeforeTrain, please check"); - return new int[0]; + return new HashMap<>(); } float rawDataBeforeTrain = dataBeforeTrain[j]; float updateData = rawData - rawDataBeforeTrain; @@ -432,13 +433,11 @@ public class SecureProtocol { updateData += gaussianNoise; data2[j] = rawDataBeforeTrain + updateData; data2[j] = data2[j] * trainDataSize; + featureMap.add(data2[j]); } - int featureName = builder.createString(key); - int weight = FeatureMap.createDataVector(builder, data2); - int featureMap = FeatureMap.createFeatureMap(builder, featureName, weight); - featuresMap[i] = featureMap; + featureMaps.put(key, featureMap); } - return featuresMap; + return featureMaps; } /** diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/StartFLJob.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/StartFLJob.java index e3b9224101e..125c52970ea 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/StartFLJob.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/StartFLJob.java @@ -18,6 +18,7 @@ package com.mindspore.flclient; import com.google.flatbuffers.FlatBufferBuilder; +import com.mindspore.flclient.compression.DecodeExecutor; import com.mindspore.flclient.model.AlInferBert; import com.mindspore.flclient.model.AlTrainBert; import com.mindspore.flclient.model.Client; @@ -29,6 +30,7 @@ import com.mindspore.flclient.model.TrainLenet; import com.mindspore.flclient.pki.PkiBean; import com.mindspore.flclient.pki.PkiUtil; +import mindspore.schema.*; import mindspore.schema.FLPlan; import mindspore.schema.FeatureMap; import mindspore.schema.RequestFLJob; @@ -38,6 +40,7 @@ import mindspore.schema.ResponseFLJob; import java.io.IOException; import java.security.cert.Certificate; 
import java.util.ArrayList; +import java.util.List; import java.util.logging.Logger; import static com.mindspore.flclient.LocalFLParameter.ALBERT; @@ -119,6 +122,7 @@ public class StartFLJob { .iteration(iteration) .signData(pkiBean.getSignData()) .certificateChain(pkiBean.getCertificates()) + .downloadCompressTypesBuilder(flParameter.getDownloadCompressTypes()) .build(); } return builder.flName(flParameter.getFlName()) @@ -126,6 +130,7 @@ public class StartFLJob { .id(localFLParameter.getFlID()) .dataSize(dataSize) .iteration(iteration) + .downloadCompressTypesBuilder(flParameter.getDownloadCompressTypes()) .build(); } @@ -151,8 +156,9 @@ public class StartFLJob { ArrayList albertFeatureMaps = new ArrayList(); ArrayList inferFeatureMaps = new ArrayList(); featureSize = 0; - for (int i = 0; i < fmCount; i++) { - FeatureMap feature = flJob.featureMap(i); + List featureMapList = parseFeatureMapList(flJob); + for (int i = 0; i < featureMapList.size(); i++) { + FeatureMap feature = featureMapList.get(i); if (feature == null) { LOGGER.severe(Common.addTag("[startFLJob] the feature returned from server is null")); return FLClientStatus.FAILED; @@ -233,12 +239,14 @@ public class StartFLJob { private FLClientStatus deprecatedParseResponseLenet(ResponseFLJob flJob) { FLClientStatus status; - int fmCount = flJob.featureMapLength(); - ArrayList featureMaps = new ArrayList(); updateFeatureName.clear(); featureSize = 0; - for (int i = 0; i < fmCount; i++) { - FeatureMap feature = flJob.featureMap(i); + List featureMapList = parseFeatureMapList(flJob); + + ArrayList featureMaps = new ArrayList<>(); + + for (int i = 0; i < featureMapList.size(); i++) { + FeatureMap feature = featureMapList.get(i); if (feature == null) { LOGGER.severe(Common.addTag("[startFLJob] the feature returned from server is null")); return FLClientStatus.FAILED; @@ -267,6 +275,24 @@ public class StartFLJob { return FLClientStatus.SUCCESS; } + private List parseFeatureMapList(ResponseFLJob flJob) { + List 
featureMaps; + byte compressType = flJob.downloadCompressType(); + if (flJob.downloadCompressType() == mindspore.schema.CompressType.NO_COMPRESS) { + LOGGER.info(Common.addTag("[parseFeatureMapList] create no compress feature map.")); + featureMaps = new ArrayList<>(); + for (int i = 0; i < flJob.featureMapLength(); i++) { + featureMaps.add(flJob.featureMap(i)); + } + } else { + List compressFeatureMapList = new ArrayList<>(); + for (int i = 0; i < flJob.compressFeatureMapLength(); i++) { + compressFeatureMapList.add(flJob.compressFeatureMap(i)); + } + featureMaps = DecodeExecutor.getInstance().deCompressWeight(compressType, compressFeatureMapList); + } + return featureMaps; + } private FLClientStatus hybridFeatures(ResponseFLJob flJob) { FLClientStatus status; @@ -275,8 +301,23 @@ public class StartFLJob { ArrayList trainFeatureMaps = new ArrayList(); ArrayList inferFeatureMaps = new ArrayList(); featureSize = 0; + List featureMaps; + byte compressType = flJob.downloadCompressType(); + if (compressType == CompressType.NO_COMPRESS) { + featureMaps = new ArrayList<>(); + for (int i = 0; i < fmCount; i++) { + featureMaps.add(flJob.featureMap(i)); + } + } else { + List compressFeatureMapList = new ArrayList<>(); + for (int i = 0; i < flJob.compressFeatureMapLength(); i++) { + compressFeatureMapList.add(flJob.compressFeatureMap(i)); + } + featureMaps = DecodeExecutor.getInstance().deCompressWeight(compressType, compressFeatureMapList); + fmCount = featureMaps.size(); + } for (int i = 0; i < fmCount; i++) { - FeatureMap feature = flJob.featureMap(i); + FeatureMap feature = featureMaps.get(i); if (feature == null) { LOGGER.severe(Common.addTag("[startFLJob] the feature returned from server is null")); retCode = ResponseCode.SystemError; @@ -335,8 +376,23 @@ public class StartFLJob { int fmCount = flJob.featureMapLength(); ArrayList featureMaps = new ArrayList(); featureSize = 0; + byte compressType = flJob.downloadCompressType(); + List parseFeatureMaps; + if 
(compressType == CompressType.NO_COMPRESS) { + parseFeatureMaps = new ArrayList<>(); + for (int i = 0; i < fmCount; i++) { + parseFeatureMaps.add(flJob.featureMap(i)); + } + } else { + List compressFeatureMapList = new ArrayList<>(); + for (int i = 0; i < flJob.compressFeatureMapLength(); i++) { + compressFeatureMapList.add(flJob.compressFeatureMap(i)); + } + parseFeatureMaps = DecodeExecutor.getInstance().deCompressWeight(compressType, compressFeatureMapList); + fmCount = parseFeatureMaps.size(); + } for (int i = 0; i < fmCount; i++) { - FeatureMap feature = flJob.featureMap(i); + FeatureMap feature = parseFeatureMaps.get(i); if (feature == null) { LOGGER.severe(Common.addTag("[startFLJob] the feature returned from server is null")); retCode = ResponseCode.SystemError; @@ -437,8 +493,8 @@ public class StartFLJob { switch (responseRetCode) { case (ResponseCode.SUCCEED): - if (flJob.featureMapLength() <= 0) { - LOGGER.severe(Common.addTag("[startFLJob] the feature size get from server is zero")); + if (flJob.downloadCompressType() == CompressType.NO_COMPRESS && flJob.featureMapLength() <= 0) { + LOGGER.warning(Common.addTag("[startFLJob] the feature size get from server is zero")); retCode = ResponseCode.SystemError; return FLClientStatus.FAILED; } @@ -484,6 +540,7 @@ public class StartFLJob { private int equipCertOffset = 0; private int equipCACertOffset = 0; private int rootCertOffset = 0; + private int downloadCompressTypesOffset = 0; public RequestStartFLJobBuilder() { builder = new FlatBufferBuilder(); @@ -598,6 +655,17 @@ public class StartFLJob { return this; } + private RequestStartFLJobBuilder downloadCompressTypesBuilder(byte[] downloadCompressTypes) { + if (downloadCompressTypes == null || downloadCompressTypes.length == 0) { + LOGGER.severe(Common.addTag("[StartFLJob] the parameter of is null or empty," + + " please check!")); + throw new IllegalArgumentException(); + } + this.downloadCompressTypesOffset = 
RequestFLJob.createDownloadCompressTypesVector(builder, + downloadCompressTypes); + return this; + } + /** * build protobuffer * @@ -615,6 +683,7 @@ public class StartFLJob { RequestFLJob.addEquipCaCert(builder, equipCACertOffset); RequestFLJob.addEquipCert(builder, equipCertOffset); RequestFLJob.addKeyAttestation(builder, keyAttestationOffset); + RequestFLJob.addDownloadCompressTypes(builder, downloadCompressTypesOffset); int root = RequestFLJob.endRequestFLJob(builder); builder.finish(root); return builder.sizedByteArray(); diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/SyncFLJob.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/SyncFLJob.java index 5343a6657b9..29641f41b8f 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/SyncFLJob.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/SyncFLJob.java @@ -147,6 +147,10 @@ public class SyncFLJob { LOGGER.info(Common.addTag("[startFLJob] startFLJob succeed, curIteration: " + flLiteClient.getIteration())); updateTryTimePerIter(flLiteClient); + // Copy weights before training. 
+ Map oldFeatureMap = flLiteClient.getFeatureMap(); + localFLParameter.setOldFeatureMap(oldFeatureMap); + // create mask curStatus = flLiteClient.getFeatureMask(); if (curStatus == FLClientStatus.RESTART) { diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/UpdateModel.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/UpdateModel.java index 5e7661dbe4c..0dda5958a1b 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/UpdateModel.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/UpdateModel.java @@ -26,11 +26,15 @@ import com.mindspore.flclient.model.SessionUtil; import com.mindspore.flclient.model.Status; import com.mindspore.flclient.model.TrainLenet; import com.mindspore.lite.MSTensor; +import com.mindspore.flclient.compression.EncodeExecutor; +import com.mindspore.flclient.compression.CompressWeight; import mindspore.schema.FeatureMap; +import mindspore.schema.CompressFeatureMap; import mindspore.schema.RequestUpdateModel; import mindspore.schema.ResponseCode; import mindspore.schema.ResponseUpdateModel; +import static mindspore.schema.CompressType.NO_COMPRESS; import java.util.ArrayList; import java.util.Date; @@ -208,6 +212,7 @@ public class UpdateModel { private RequestUpdateModel requestUM; private FlatBufferBuilder builder; private int fmOffset = 0; + private int compFmOffset = 0; private int nameOffset = 0; private int idOffset = 0; private int timestampOffset = 0; @@ -215,8 +220,11 @@ public class UpdateModel { private int sign = 0; private int indexArrayOffset = 0; private int iteration = 0; + private byte uploadCompressType = 0; + private float uploadSparseRate = 0.0f; private EncryptLevel encryptLevel = EncryptLevel.NOT_ENCRYPT; private float uploadLossOffset = 0.0f; + private int nameVecOffset = 0; private RequestUpdateModelBuilder(EncryptLevel encryptLevel) { builder = new FlatBufferBuilder(); @@ -294,34 +302,33 @@ public 
class UpdateModel { } else { trainedMap = getFeatureMap(); } + Map> featureMaps = new HashMap<>(); long startTime; long endTime; switch (encryptLevel) { case PW_ENCRYPT: - int[] fmOffsetsPW = secureProtocol.pwMaskModel(builder, trainDataSize, trainedMap); - if (fmOffsetsPW == null || fmOffsetsPW.length == 0) { - LOGGER.severe("[Encrypt] the return fmOffsetsPW from is " + + featureMaps = secureProtocol.pwMaskModel(builder, trainDataSize, trainedMap); + if (featureMaps == null || featureMaps.size() == 0) { + LOGGER.severe("[Encrypt] the return featureMaps from is " + "null, please check"); throw new IllegalArgumentException(); } - this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsetsPW); LOGGER.info(Common.addTag("[Encrypt] pairwise mask model ok!")); - return this; + break; case DP_ENCRYPT: startTime = System.currentTimeMillis(); - int[] fmOffsetsDP = secureProtocol.dpMaskModel(builder, trainDataSize, trainedMap); - if (fmOffsetsDP == null || fmOffsetsDP.length == 0) { - LOGGER.severe("[Encrypt] the return fmOffsetsDP from is " + + featureMaps = secureProtocol.dpMaskModel(builder, trainDataSize, trainedMap); + if (featureMaps == null || featureMaps.size() == 0) { + LOGGER.severe("[Encrypt] the return featureMaps from is " + "null, please check"); retCode = ResponseCode.RequestError; status = FLClientStatus.FAILED; throw new IllegalArgumentException(); } - this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsetsDP); LOGGER.info(Common.addTag("[Encrypt] DP mask model ok!")); endTime = System.currentTimeMillis(); - LOGGER.info(Common.addTag("[Encrypt] dp time is: " + (endTime - startTime) + "ms")); - return this; + LOGGER.info(Common.addTag("dp time is " + (endTime - startTime) + "ms")); + break; case SIGNDS: startTime = System.currentTimeMillis(); // signds alg return indexArray, and package indexArray into flatbuffer. 
@@ -352,31 +359,104 @@ public class UpdateModel { this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsetsSignds); LOGGER.info(Common.addTag("[Encrypt] SignDS mask model ok!")); endTime = System.currentTimeMillis(); - LOGGER.info(Common.addTag("[Encrypt] signds time is: " + (endTime - startTime) + "ms")); + LOGGER.info(Common.addTag("signds time is " + (endTime - startTime) + "ms")); return this; case NOT_ENCRYPT: default: startTime = System.currentTimeMillis(); - int featureSize = updateFeatureName.size(); - int[] fmOffsets = new int[featureSize]; - for (int i = 0; i < featureSize; i++) { - String key = updateFeatureName.get(i); - float[] data = trainedMap.get(key); - LOGGER.info(Common.addTag("[updateModel build featuresMap] feature name: " + key + " feature " + - "size: " + data.length)); - for (int j = 0; j < data.length; j++) { - data[j] = data[j] * trainDataSize; + for (String name : updateFeatureName) { + float[] data = trainedMap.get(name); + List featureMap = new ArrayList<>(); + for (float datum : data) { + featureMap.add(datum * (float) trainDataSize); } - int featureName = builder.createString(key); - int weight = FeatureMap.createDataVector(builder, data); - int featureMap = FeatureMap.createFeatureMap(builder, featureName, weight); - fmOffsets[i] = featureMap; + featureMaps.put(name, featureMap); } - this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsets); endTime = System.currentTimeMillis(); - LOGGER.info(Common.addTag("[Encrypt] not encrypt time is: " + (endTime - startTime) + "ms")); - return this; + LOGGER.info(Common.addTag("not encrypt time is " + (endTime - startTime) + "ms")); + break; } + byte uploadCompressType = localFLParameter.getUploadCompressType(); + if (uploadCompressType != NO_COMPRESS) { + startTime = System.currentTimeMillis(); + this.compFmOffset = buildCompFmOffset(featureMaps, trainDataSize); + this.uploadCompressType = localFLParameter.getUploadCompressType(); + this.uploadSparseRate = 
localFLParameter.getUploadSparseRatio(); + this.nameVecOffset = buildNameVecOffset(updateFeatureName); + endTime = System.currentTimeMillis(); + LOGGER.info(Common.addTag("compression time is " + (endTime - startTime) + "ms")); + return this; + } + this.fmOffset = buildFmOffset(featureMaps, updateFeatureName); + return this; + } + + private int buildCompFmOffset(Map> featureMaps, int trainDataSize) { + List compressWeights = EncodeExecutor.getInstance().encode(featureMaps, trainDataSize); + if (compressWeights == null || compressWeights.size() == 0) { + LOGGER.severe("[Compression] the return compressWeights from is " + + "null, please check"); + retCode = ResponseCode.RequestError; + status = FLClientStatus.FAILED; + throw new IllegalArgumentException(); + } + int compFeatureSize = compressWeights.size(); + int[] compFmOffsets = new int[compFeatureSize]; + int index = 0; + for (CompressWeight compressWeight : compressWeights) { + String weightFullname = compressWeight.getWeightFullname(); + List compressData = compressWeight.getCompressData(); + float minVal = compressWeight.getMinValue(); + float maxVal = compressWeight.getMaxValue(); + byte[] data = new byte[compressData.size()]; + LOGGER.info(Common.addTag("[updateModel build compressWeight] feature name: " + + weightFullname + ", feature size: " + data.length)); + for (int j = 0; j < data.length; j++) { + data[j] = compressData.get(j); + } + int featureName = builder.createString(weightFullname); + int weight = CompressFeatureMap.createCompressDataVector(builder, data); + int featureMap = CompressFeatureMap.createCompressFeatureMap(builder, featureName, weight, + minVal, maxVal); + LOGGER.info(Common.addTag("[Compression]" + + " featureName: " + weightFullname + + ", min_val: " + minVal + + ", max_val: " + maxVal)); + compFmOffsets[index] = featureMap; + index += 1; + } + return RequestUpdateModel.createCompressFeatureMapVector(builder, compFmOffsets); + } + + private int buildNameVecOffset(ArrayList 
updateFeatureName) { + int featureSize = updateFeatureName.size(); + int[] nameVecOffsets = new int[featureSize]; + for (int i = 0; i < featureSize; i++) { + String key = updateFeatureName.get(i); + int featureName = builder.createString(key); + nameVecOffsets[i] = featureName; + } + return RequestUpdateModel.createNameVecVector(builder, nameVecOffsets); + } + + private int buildFmOffset(Map> featureMaps, ArrayList updateFeatureName) { + int featureSize = updateFeatureName.size(); + int[] fmOffsets = new int[featureSize]; + for (int i = 0; i < featureSize; i++) { + String key = updateFeatureName.get(i); + List featureMap = featureMaps.get(key); + float[] data = new float[featureMap.size()]; + LOGGER.info(Common.addTag("[updateModel build featuresMap] feature name: " + key + " feature " + + "size: " + data.length)); + for (int j = 0; j < data.length; j++) { + data[j] = featureMap.get(j); + } + int featureName = builder.createString(key); + int weight = FeatureMap.createDataVector(builder, data); + int featureMapOff = FeatureMap.createFeatureMap(builder, featureName, weight); + fmOffsets[i] = featureMapOff; + } + return RequestUpdateModel.createFeatureMapVector(builder, fmOffsets); } /** @@ -417,6 +497,10 @@ public class UpdateModel { RequestUpdateModel.addFlId(this.builder, idOffset); RequestUpdateModel.addTimestamp(builder, this.timestampOffset); RequestUpdateModel.addIteration(builder, this.iteration); + RequestUpdateModel.addCompressFeatureMap(builder, this.compFmOffset); + RequestUpdateModel.addUploadCompressType(builder, this.uploadCompressType); + RequestUpdateModel.addUploadSparseRate(builder, this.uploadSparseRate); + RequestUpdateModel.addNameVec(builder, this.nameVecOffset); RequestUpdateModel.addFeatureMap(builder, this.fmOffset); RequestUpdateModel.addSignature(builder, this.signDataOffset); RequestUpdateModel.addUploadLoss(builder, this.uploadLossOffset); diff --git 
a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/compression/CompressMode.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/compression/CompressMode.java new file mode 100644 index 00000000000..890ed26c6d9 --- /dev/null +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/compression/CompressMode.java @@ -0,0 +1,40 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2019-2022. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mindspore.flclient.compression; + +import java.util.HashMap; +import java.util.Map; + +import static mindspore.schema.CompressType.NO_COMPRESS; +import static mindspore.schema.CompressType.QUANT; + +/** + * The compress mod. 
+ * + * @since 2021-12-21 + */ + +public class CompressMode { + // compress type -> num bits + public static final Map COMPRESS_TYPE_MAP = new HashMap<>(); + + static { + COMPRESS_TYPE_MAP.put(NO_COMPRESS, -1); + COMPRESS_TYPE_MAP.put(QUANT, 8); + } + +} \ No newline at end of file diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/compression/CompressWeight.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/compression/CompressWeight.java new file mode 100644 index 00000000000..b80ba70a1bf --- /dev/null +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/compression/CompressWeight.java @@ -0,0 +1,83 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2019-2022. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.mindspore.flclient.compression; + +import java.util.List; + +/** + * Compress Weight Bean + * + * @since 2021-12-21 + */ +public class CompressWeight { + private String weightFullname; + private List compressData; + private float minValue; + private float maxValue; + + public CompressWeight() { + } + + public CompressWeight(String weightFullname, List compressData, float minValue, float maxValue) { + this.weightFullname = weightFullname; + this.compressData = compressData; + this.minValue = minValue; + this.maxValue = maxValue; + } + + public String getWeightFullname() { + return weightFullname; + } + + public void setWeightFullname(String weightFullname) { + this.weightFullname = weightFullname; + } + + public List getCompressData() { + return compressData; + } + + public void setCompressData(List compressData) { + this.compressData = compressData; + } + + public float getMinValue() { + return minValue; + } + + public void setMinValue(float minValue) { + this.minValue = minValue; + } + + public float getMaxValue() { + return maxValue; + } + + public void setMaxValue(float maxValue) { + this.maxValue = maxValue; + } + + @Override + public String toString() { + return "CompressWeight{" + + "weightFullname='" + weightFullname + '\'' + + ", compressData=" + compressData + + ", minValue=" + minValue + + ", maxValue=" + maxValue + + '}'; + } +} diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/compression/DecodeExecutor.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/compression/DecodeExecutor.java new file mode 100644 index 00000000000..0d72ce7aeec --- /dev/null +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/compression/DecodeExecutor.java @@ -0,0 +1,115 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2019-2022. All rights reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mindspore.flclient.compression; + +import com.google.flatbuffers.FlatBufferBuilder; + +import com.mindspore.flclient.Common; +import com.mindspore.flclient.StartFLJob; +import mindspore.schema.CompressFeatureMap; +import mindspore.schema.FeatureMap; +import mindspore.schema.CompressType; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.logging.Logger; + +import static mindspore.schema.CompressType.QUANT; +/** + * Compress Executor + * + * @since 2021-12-21 + */ +public class DecodeExecutor { + private static final Logger LOGGER = Logger.getLogger(DecodeExecutor.class.toString()); + + private static volatile DecodeExecutor compressExecutor; + + private DecodeExecutor() {} + + public static DecodeExecutor getInstance() { + if (compressExecutor == null) { + synchronized (DecodeExecutor.class) { + if (compressExecutor == null) { + compressExecutor = new DecodeExecutor(); + } + } + } + return compressExecutor; + } + + public List deCompressWeight(byte compressType, List compressFeatureMapList) { + if (!CompressMode.COMPRESS_TYPE_MAP.containsKey(compressType)) { + return new ArrayList<>(); + } + LOGGER.info(Common.addTag("[deCompressWeight] create " + CompressType.name(compressType) + " feature map.")); + int num_bits = CompressMode.COMPRESS_TYPE_MAP.get(compressType); + if (compressType == 
QUANT) { + return deCompressQuantMinMax(compressFeatureMapList, num_bits); + } + return new ArrayList<>(); + } + + private List deCompressQuantMinMax(List compressFeatureMapList, int num_bits) { + float temp1 = (float) (Math.pow(2, num_bits) - 1); + float temp2 = (float) Math.pow(2, num_bits - 1); + + Map deCompressFeatureMaps = new HashMap<>(); + int compressFeatureMapLength = compressFeatureMapList.size(); + for (int i = 0; i < compressFeatureMapLength; i++) { + CompressFeatureMap compressFeatureMap = compressFeatureMapList.get(i); + String weightName = compressFeatureMap.weightFullname(); + int compressDataLength = compressFeatureMap.compressDataLength(); + List compressWeightList = new ArrayList<>(); + for (int j = 0; j < compressDataLength; j++) { + compressWeightList.add(compressFeatureMap.compressData(j)); + } + float minVal = compressFeatureMap.minVal(); + float maxVal = compressFeatureMap.maxVal(); + float scale_value = (float) ((maxVal - minVal) / temp1 + 1e-10); + float[] params = new float[compressWeightList.size()]; + for (int j = 0; j < params.length; j++) { + float val = (compressWeightList.get(j).intValue() + temp2) * scale_value + minVal; + params[j] = val; + } + deCompressFeatureMaps.put(weightName, params); + } + + List featureMaps = new ArrayList<>(); + for (String weightName : deCompressFeatureMaps.keySet()) { + FlatBufferBuilder builder = new FlatBufferBuilder(0); + int weightFullnameOffset = builder.createString(weightName); + float[] data = deCompressFeatureMaps.get(weightName); + int dataOffset = FeatureMap.createDataVector(builder, data); + + FeatureMap.startFeatureMap(builder); + FeatureMap.addWeightFullname(builder, weightFullnameOffset); + FeatureMap.addData(builder, dataOffset); + + int orc = FeatureMap.endFeatureMap(builder); + builder.finish(orc); + ByteBuffer buf = builder.dataBuffer(); + FeatureMap featureMap = FeatureMap.getRootAsFeatureMap(buf); + + featureMaps.add(featureMap); + } + return featureMaps; + } +} \ No newline at end 
of file diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/compression/EncodeExecutor.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/compression/EncodeExecutor.java new file mode 100644 index 00000000000..7f3eb42864a --- /dev/null +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/compression/EncodeExecutor.java @@ -0,0 +1,167 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2019-2022. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.mindspore.flclient.compression; + +import com.mindspore.flclient.LocalFLParameter; +import static mindspore.schema.CompressType.DIFF_SPARSE_QUANT; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Queue; +import java.util.PriorityQueue; + +/** + * Encode Executor + * + * @since 2021-12-21 + */ +public class EncodeExecutor { + private final LocalFLParameter localFLParameter = LocalFLParameter.getInstance(); + + private static volatile EncodeExecutor encodeExecutor; + + private EncodeExecutor() {} + + public static EncodeExecutor getInstance() { + if (encodeExecutor == null) { + synchronized (EncodeExecutor.class) { + if (encodeExecutor == null) { + encodeExecutor = new EncodeExecutor(); + } + } + } + return encodeExecutor; + } + + private static final int multiplier = 2147483647; + private static final double increment = 4294967294.0; + private static final int modulo = 48271; + + private List constructMaskArray(int paramNum) { + int seed = localFLParameter.getSeed(); + float uploadSparseRatio = localFLParameter.getUploadSparseRatio(); + + List maskArray = new ArrayList<>(); + + int retain_num = (int) ((float) (paramNum) * uploadSparseRatio); + for (int i = 0; i < retain_num; ++i) { + maskArray.add(1); + } + for (int i = retain_num; i < paramNum; ++i) { + maskArray.add(0); + } + + seed = ((seed + multiplier) * modulo) % multiplier; + for (int i = 0; i < paramNum; ++i) { + // generate random number in (0, 1) + double rand = (double)(seed) / increment + 0.5; + // update seed + seed = (seed * modulo) % multiplier; + + int j = (int)(rand * (double)(paramNum - i)) + i; + int temp = maskArray.get(i); + maskArray.set(i, maskArray.get(j)); + maskArray.set(j, temp); + } + return maskArray; + } + + public List enDiffSparseQuant(Map> featureMaps, int numBits, + int trainDataSize) { + List compressWeights = new ArrayList<>(); + + // difference encode + Map oldFeatureMap = 
localFLParameter.getOldFeatureMap(); + Map> diffFeatureMaps = new HashMap<>(); + for (String featureMapName : featureMaps.keySet()) { + List diffs = new ArrayList<>(); + List featureMap = featureMaps.get(featureMapName); + float[] dataBeforeTrain = oldFeatureMap.get(featureMapName); + int length = dataBeforeTrain.length; + for (int i = 0; i < length; ++i) { + float diff = featureMap.get(i) - dataBeforeTrain[i] * (float) trainDataSize; + diffs.add(diff); + } + diffFeatureMaps.put(featureMapName, diffs); + } + + // sparse encode + int paramNum = 0; + for (String featureMapName : diffFeatureMaps.keySet()) { + int weightSize = diffFeatureMaps.get(featureMapName).size(); + paramNum += weightSize; + } + List maskArray = constructMaskArray(paramNum); + + Map> sparseFeatureMaps = new HashMap<>(); + int index = 0; + for (String featureMapName : diffFeatureMaps.keySet()) { + List sparseFeatureMap = new ArrayList<>(); + List Weight = diffFeatureMaps.get(featureMapName); + for (Float dataValue : Weight) { + if (maskArray.get(index) == 1) { + sparseFeatureMap.add(dataValue); + } + index += 1; + } + sparseFeatureMaps.put(featureMapName, sparseFeatureMap); + } + + // quant encode + float temp1 = (float) (1 << numBits) - 1.0f; + float temp2 = (float) (1 << (numBits - 1)); + for (String featureMapName : sparseFeatureMaps.keySet()) { + CompressWeight compressWeight = new CompressWeight(); + compressWeight.setWeightFullname(featureMapName); + + List sparseFeatureMap = sparseFeatureMaps.get(featureMapName); + + // get min and max value + Float minVal = Float.MAX_VALUE; + float maxVal = -minVal; + for (Float value : sparseFeatureMap) { + if (value < minVal) { + minVal = value; + } + if (value > maxVal) { + maxVal = value; + } + } + compressWeight.setMinValue(minVal); + compressWeight.setMaxValue(maxVal); + float scale_value = (maxVal - minVal) / temp1 + 1e-10f; + List compressData = new ArrayList<>(); + for (Float aFloat : sparseFeatureMap) { + compressData.add((byte) 
(Math.round((aFloat - minVal) / scale_value - temp2))); + } + compressWeight.setCompressData(compressData); + compressWeights.add(compressWeight); + } + + return compressWeights; + } + + public List encode(Map> featureMaps, int trainDataSize) { + byte uploadCompressType = localFLParameter.getUploadCompressType(); + if (uploadCompressType == DIFF_SPARSE_QUANT) { + return enDiffSparseQuant(featureMaps, 8, trainDataSize); + } + throw new IllegalArgumentException(); + } +} \ No newline at end of file diff --git a/mindspore/python/mindspore/parallel/_ps_context.py b/mindspore/python/mindspore/parallel/_ps_context.py index df79d8755f8..fc9dc906b2b 100644 --- a/mindspore/python/mindspore/parallel/_ps_context.py +++ b/mindspore/python/mindspore/parallel/_ps_context.py @@ -15,8 +15,9 @@ """Context for parameter server training mode""" import os -from mindspore._checkparam import Validator +from mindspore._checkparam import Validator, Rel from mindspore._c_expression import PSContext +from mindspore import log as logger _ps_context = None @@ -79,6 +80,9 @@ _set_ps_context_func_map = { "sign_global_lr": ps_context().set_sign_global_lr, "sign_dim_out": ps_context().set_sign_dim_out, "checkpoint_dir": ps_context().set_checkpoint_dir, + "upload_compress_type": ps_context().set_upload_compress_type, + "upload_sparse_rate": ps_context().set_upload_sparse_rate, + "download_compress_type": ps_context().set_download_compress_type, } _get_ps_context_func_map = { @@ -126,7 +130,10 @@ _get_ps_context_func_map = { "sign_thr_ratio": ps_context().sign_thr_ratio, "sign_global_lr": ps_context().sign_global_lr, "sign_dim_out": ps_context().sign_dim_out, - "checkpoint_dir": ps_context().checkpoint_dir + "checkpoint_dir": ps_context().checkpoint_dir, + "upload_compress_type": ps_context().upload_compress_type, + "upload_sparse_rate": ps_context().upload_sparse_rate, + "download_compress_type": ps_context().download_compress_type, } _check_positive_int_keys = ["server_num", "scheduler_port", 
"fl_server_port", @@ -140,6 +147,15 @@ _check_positive_float_keys = ["update_model_ratio", "client_learning_rate"] _check_port_keys = ["scheduler_port", "fl_server_port"] +_check_string_keys = { + "upload_compress_type": ["NO_COMPRESS", "DIFF_SPARSE_QUANT"], + "download_compress_type": ["NO_COMPRESS", "QUANT"], +} + +_check_float_range_keys = { + "upload_sparse_rate": {"lower_limit": 0.0, "upper_limit": 1.0, "rel": Rel.INC_RIGHT}, +} + def _get_ps_mode_rank(): ps_rank = ps_context().ps_rank_id() if ps_rank == -1: @@ -183,6 +199,7 @@ def _set_ps_context(**kwargs): Examples: >>> context.set_ps_context(enable_ps=True, enable_ssl=True, client_password='123456', server_password='123456') """ + kwargs = _check_conflict_value(kwargs) for key, value in kwargs.items(): if key not in _set_ps_context_func_map: raise ValueError("Set PS context keyword %s is not recognized!" % key) @@ -287,6 +304,31 @@ def _check_value(key, value): if key in _check_positive_float_keys: Validator.check_positive_float(value, key) + if key in _check_string_keys: + try: + string_keys = _check_string_keys[key] + Validator.check_string(value, string_keys) + except KeyError: + pass + + if key in _check_float_range_keys: + try: + range_keys = _check_float_range_keys[key] + Validator.check_float_range(value, **range_keys) + except KeyError: + pass + if key in _check_port_keys: if value < 1 or value > 65535: raise ValueError("The range of %s must be 1 to 65535, but got %d." 
% (key, value)) + + +def _check_conflict_value(kwargs): + if "upload_compress_type" in kwargs and " encrypt_type" in kwargs: + if kwargs["upload_compress_type"] != "NO_COMPRESS" and kwargs["encrypt_type"] in ("SIGNDS", "PW_ENCRYPT"): + logger.warning("The '{}' and '{}' are conflicted, and in '{}' mode the" + " 'upload_compress_type' will be 'NO_COMPRESS'".format(kwargs["encrypt_type"], + kwargs["upload_compress_type"], + kwargs["encrypt_type"])) + kwargs["upload_compress_type"] = "NO_COMPRESS" + return kwargs diff --git a/mindspore/schema/fl_job.fbs b/mindspore/schema/fl_job.fbs index d443c3a7153..dd4d74188da 100644 --- a/mindspore/schema/fl_job.fbs +++ b/mindspore/schema/fl_job.fbs @@ -47,6 +47,16 @@ table FeatureMap{ weight_fullname:string; data:[float]; } + +enum CompressType:byte {NO_COMPRESS = 0, DIFF_SPARSE_QUANT = 1, QUANT = 2} + +table CompressFeatureMap{ + weight_fullname:string; + compress_data:[int8]; + min_val:float; + max_val:float; +} + table RequestFLJob{ fl_name:string; fl_id:string; @@ -58,6 +68,7 @@ table RequestFLJob{ equip_cert:string; equip_ca_cert:string; root_cert:string; + download_compress_types:[CompressType]; } table ResponseFLJob { retcode:int; @@ -68,6 +79,10 @@ table ResponseFLJob { fl_plan_config:FLPlan; feature_map:[FeatureMap]; timestamp:string; + upload_compress_type:CompressType; + upload_sparse_rate:float; + download_compress_type:CompressType; + compress_feature_map:[CompressFeatureMap]; } table FLPlan { @@ -94,6 +109,10 @@ table RequestUpdateModel{ upload_loss:float; sign:int; index_array:[int]; + compress_feature_map:[CompressFeatureMap]; + upload_compress_type:CompressType; + upload_sparse_rate:float; + name_vec:[string]; } table ResponseUpdateModel{ @@ -132,6 +151,7 @@ table RequestGetModel{ fl_name:string; iteration:int; timestamp:string; + download_compress_types:[CompressType]; } table ResponseGetModel{ retcode:int; @@ -139,6 +159,8 @@ table ResponseGetModel{ iteration:int; feature_map:[FeatureMap]; timestamp:string; + 
download_compress_type:CompressType; + compress_feature_map:[CompressFeatureMap]; } table RequestAsyncGetModel{