!21202 Fix hybrid training with pairwise encrypt

Merge pull request !21202 from ZPaC/fix-hybrid-with-pw-mask
This commit is contained in:
i-robot 2021-08-03 02:50:56 +00:00 committed by Gitee
commit 04290b661a
12 changed files with 95 additions and 15 deletions

View File

@ -231,6 +231,9 @@ bool Executor::IsWeightAggrDone(const std::vector<std::string> &param_names) {
std::unique_lock<std::mutex> lock(mtx);
auto &param_aggr = param_aggrs_[name];
MS_ERROR_IF_NULL_W_RET_VAL(param_aggr, false);
if (!param_aggr->requires_aggr()) {
continue;
}
if (!param_aggr->IsAggregationDone()) {
MS_LOG(DEBUG) << "Update model for " << name << " is not done yet.";
return false;
@ -265,6 +268,8 @@ std::map<std::string, AddressPtr> Executor::GetModel() {
return model;
}
const std::vector<std::string> &Executor::param_names() const { return param_names_; }
bool Executor::Unmask() {
#ifdef ENABLE_ARMOUR
auto model = GetModel();
@ -274,7 +279,17 @@ bool Executor::Unmask() {
#endif
}
const std::vector<std::string> &Executor::param_names() const { return param_names_; }
void Executor::set_unmasked(bool unmasked) { unmasked_ = unmasked; }
bool Executor::unmasked() const {
std::string encrypt_type = ps::PSContext::instance()->encrypt_type();
if (encrypt_type == ps::kPWEncryptType) {
return unmasked_.load();
} else {
// If the algorithm of pairwise encrypt is not enabled, consider_ unmasked flag as true.
return true;
}
}
std::string Executor::GetTrainableParamName(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode);

View File

@ -93,10 +93,16 @@ class Executor {
bool initialized() const;
const std::vector<std::string> &param_names() const;
// The unmasking method for pairwise encrypt algorithm.
bool Unmask();
// The setter and getter for unmasked flag to judge whether the unmasking is completed.
void set_unmasked(bool unmasked);
bool unmasked() const;
private:
Executor() : initialized_(false), aggregation_count_(0), param_names_({}), param_aggrs_({}) {}
Executor() : initialized_(false), aggregation_count_(0), param_names_({}), param_aggrs_({}), unmasked_(false) {}
~Executor() = default;
Executor(const Executor &) = delete;
Executor &operator=(const Executor &) = delete;
@ -123,9 +129,13 @@ class Executor {
// Because ParameterAggregator is not threadsafe, we have to create mutex for each ParameterAggregator so we can
// acquire lock before calling its method.
std::map<std::string, std::mutex> parameter_mutex_;
#ifdef ENABLE_ARMOUR
armour::CipherUnmask cipher_unmask_;
#endif
// The flag represents the unmasking status.
std::atomic<bool> unmasked_;
};
} // namespace server
} // namespace fl

View File

@ -90,7 +90,7 @@ void PullWeightKernel::PullWeight(const std::shared_ptr<FBBuilder> &fbb,
for (size_t i = 0; i < weights_names_fbs->size(); i++) {
weight_names.push_back(weights_names_fbs->Get(i)->str());
}
if (!executor_->IsWeightAggrDone(weight_names)) {
if (!executor_->IsWeightAggrDone(weight_names) || !executor_->unmasked()) {
++retry_count_;
std::string reason = "The aggregation for the weights is not done yet.";
BuildPullWeightRsp(fbb, schema::ResponseCode_SucNotReady, reason, current_iter, feature_maps);

View File

@ -35,9 +35,11 @@ void ReconstructSecretsKernel::InitKernel(size_t required_cnt) {
return;
}
auto last_cnt_handler = [&](std::shared_ptr<ps::core::MessageHandler>) {
MS_LOG(INFO) << "start FinishIteration";
FinishIteration();
MS_LOG(INFO) << "end FinishIteration";
if (ps::PSContext::instance()->resetter_round() == ps::ResetterRound::kReconstructSeccrets) {
MS_LOG(INFO) << "start FinishIteration";
FinishIteration();
MS_LOG(INFO) << "end FinishIteration";
}
return;
};
auto first_cnt_handler = [&](std::shared_ptr<ps::core::MessageHandler>) { return; };
@ -146,6 +148,7 @@ void ReconstructSecretsKernel::OnLastCountEvent(const std::shared_ptr<ps::core::
std::this_thread::sleep_for(std::chrono::milliseconds(5));
}
MS_LOG(INFO) << "end unmask";
Executor::GetInstance().set_unmasked(true);
std::string worker_id = std::to_string(DistributedCountService::GetInstance().local_rank());
DistributedCountService::GetInstance().Count(name_unmask_, worker_id);
}
@ -157,6 +160,7 @@ bool ReconstructSecretsKernel::Reset() {
DistributedCountService::GetInstance().ResetCounter(name_);
DistributedCountService::GetInstance().ResetCounter(name_unmask_);
StopTimer();
Executor::GetInstance().set_unmasked(false);
cipher_reconstruct_.ClearReconstructSecrets();
return true;
}

View File

@ -174,8 +174,14 @@ bool ParameterAggregator::IsOptimizingDone() const { return optimizing_done_; }
bool ParameterAggregator::IsPullingDone() const { return pulling_done_; }
bool ParameterAggregator::requires_aggr() const { return requires_aggr_; }
bool ParameterAggregator::InitAggregationKernels(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode);
if (!JudgeRequiresAggr(cnode)) {
MS_LOG(WARNING) << "Aggregation for weight for kernel " << AnfAlgo::GetCNodeName(cnode) << " is not required.";
}
std::vector<std::string> aggr_kernel_names = SelectAggregationAlgorithm(cnode);
for (const std::string &name : aggr_kernel_names) {
auto aggr_kernel = kernel::AggregationKernelFactory::GetInstance().Create(name, cnode);
@ -340,6 +346,28 @@ std::vector<std::string> ParameterAggregator::SelectAggregationAlgorithm(const C
return aggregation_algorithm;
}
bool ParameterAggregator::JudgeRequiresAggr(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode);
std::string cnode_name = AnfAlgo::GetCNodeName(cnode);
if (kNameToIdxMap.count(cnode_name) == 0 || kNameToIdxMap.at(cnode_name).count("inputs") == 0 ||
kNameToIdxMap.at(cnode_name).at("inputs").count("weight") == 0) {
MS_LOG(EXCEPTION) << "Can't find index info of weight for kernel " << cnode_name;
return false;
}
size_t cnode_weight_idx = kNameToIdxMap.at(cnode_name).at("inputs").at("weight");
auto weight_node = AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(cnode, cnode_weight_idx), 0).first;
MS_EXCEPTION_IF_NULL(weight_node);
if (!weight_node->isa<Parameter>()) {
MS_LOG(EXCEPTION) << weight_node->fullname_with_scope() << " is not a parameter node.";
return false;
}
auto param_info = weight_node->cast<ParameterPtr>()->param_info();
MS_EXCEPTION_IF_NULL(param_info);
requires_aggr_ = param_info->requires_aggr();
return requires_aggr_;
}
template bool ParameterAggregator::AssignMemory(std::shared_ptr<kernel::OptimizerKernel> server_kernel,
const CNodePtr &cnode,
const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info,

View File

@ -57,7 +57,8 @@ class ParameterAggregator {
aggregation_done_(false),
optimizing_done_(false),
pulling_done_(true),
memory_register_(nullptr) {}
memory_register_(nullptr),
requires_aggr_(true) {}
~ParameterAggregator() = default;
// Initialize ParameterAggregator with a cnode. This cnode is normally a optimizer kernel for now.
@ -94,6 +95,9 @@ class ParameterAggregator {
bool IsOptimizingDone() const;
bool IsPullingDone() const;
// Return whether this parameter requires aggragation.
bool requires_aggr() const;
private:
// Initializing aggregation/optimizer kenerls based on the cnode. The reason of this is described in the file
// kernel/kernel_factory.h.
@ -118,6 +122,9 @@ class ParameterAggregator {
// configuration, etc.
std::vector<std::string> SelectAggregationAlgorithm(const CNodePtr &cnode);
// Judge whether the parameter needs to be aggregated.
bool JudgeRequiresAggr(const CNodePtr &cnode);
ServerMode server_mode_;
size_t required_push_count_;
size_t required_pull_count_;
@ -135,6 +142,9 @@ class ParameterAggregator {
// Here stores multiple pairs of server kernels to parameters of their Launch function.
std::vector<std::pair<std::shared_ptr<kernel::AggregationKernel>, KernelParams>> aggregation_kernel_parameters_;
std::vector<std::pair<std::shared_ptr<kernel::OptimizerKernel>, KernelParams>> optimizer_kernel_parameters_;
// Whether this parameter needs to be aggregated.
bool requires_aggr_;
};
} // namespace server
} // namespace fl

View File

@ -270,6 +270,7 @@ void PSContext::GenerateResetterRound() {
bool is_parameter_server_mode = false;
bool is_federated_learning_mode = false;
bool is_mixed_training_mode = false;
bool use_pairwise_encrypt = (encrypt_type_ == kPWEncryptType);
if (server_mode_ == kServerModePS) {
is_parameter_server_mode = true;
@ -285,7 +286,7 @@ void PSContext::GenerateResetterRound() {
binary_server_context = ((unsigned int)is_parameter_server_mode << 0) |
((unsigned int)is_federated_learning_mode << 1) |
((unsigned int)is_mixed_training_mode << 2) | ((unsigned int)secure_aggregation_ << 3);
((unsigned int)is_mixed_training_mode << 2) | ((unsigned int)use_pairwise_encrypt << 3);
if (kServerContextToResetRoundMap.count(binary_server_context) == 0) {
resetter_round_ = ResetterRound::kNoNeedToReset;
} else {

View File

@ -44,14 +44,13 @@ constexpr char kNotEncryptType[] = "NOT_ENCRYPT";
// 0: Server is in parameter server mode.
// 1: Server is in federated learning mode.
// 2: Server is in mixed training mode.
// 3: Server enables sucure aggregation.
// For example: 1010 stands for that the server is in federated learning mode and sucure aggregation is enabled.
// 3: Server enables pairwise encrypt algorithm.
// For example: 1010 stands for that the server is in federated learning mode and pairwise encrypt algorithm is enabled.
enum class ResetterRound { kNoNeedToReset, kUpdateModel, kReconstructSeccrets, kPushWeight };
const std::map<uint32_t, ResetterRound> kServerContextToResetRoundMap = {{0b0010, ResetterRound::kUpdateModel},
{0b1010, ResetterRound::kReconstructSeccrets},
{0b1100, ResetterRound::kPushWeight},
{0b0100, ResetterRound::kPushWeight},
{0b0100, ResetterRound::kUpdateModel}};
{0b0100, ResetterRound::kPushWeight}};
class PSContext {
public:

View File

@ -34,6 +34,7 @@ REGISTER_PYBIND_DEFINE(ParamInfo, ([](const py::module *m) {
.def_property("comm_fusion", &ParamInfo::comm_fusion, &ParamInfo::set_comm_fusion)
.def_property("cache_enable", &ParamInfo::cache_enable, &ParamInfo::set_cache_enable)
.def_property("cache_shape", &ParamInfo::cache_shape, &ParamInfo::set_cache_shape)
.def_property("requires_aggr", &ParamInfo::requires_aggr, &ParamInfo::set_requires_aggr)
.def(py::pickle(
[](const ParamInfo &p) { // __getstate__
return py::make_tuple(p.name(), p.requires_grad(), p.layerwise_parallel());

View File

@ -152,6 +152,7 @@ class Parameter(Tensor_):
self.is_param_ps = False
self.push_weight_to_server = False
self.pull_weight_from_server = False
self.requires_aggr = True
self._cast_type = None
self._unique = False
self.is_in_parallel = _is_in_parallel_mode()
@ -236,18 +237,22 @@ class Parameter(Tensor_):
self.init_in_server = init_in_server
self.param_info.init_in_server = init_in_server
def set_param_fl(self, push_to_server=False, pull_from_server=False):
def set_param_fl(self, push_to_server=False, pull_from_server=False, requires_aggr=True):
"""
Set the way of parameter and server interaction.
Args:
push_to_server (bool): Whether the parameter should be pushed to server. Default: False.
pull_from_server (bool): Whether the parameter should be pulled from server. Default: False.
requires_aggr (bool): Whether the parameter should be aggregated in the server. Default: True.
"""
if push_to_server:
self.push_weight_to_server = True
if pull_from_server:
self.pull_weight_from_server = True
if not requires_aggr:
self.requires_aggr = False
self.param_info.requires_aggr = False
@property
def inited_param(self):
@ -376,6 +381,7 @@ class Parameter(Tensor_):
x.is_param_ps = self.is_param_ps
x.init_in_server = self.init_in_server
x.cache_enable = self.cache_enable
x.requires_aggr = self.requires_aggr
if self.cache_shape:
x.cache_shape = self.cache_shape
if init != 'same':

View File

@ -72,6 +72,7 @@ class ParamInfo {
this->be_cloned_ = true;
this->be_cloned_index_.push_back(index);
clone->init_in_server_ = this->init_in_server_;
clone->requires_aggr_ = this->requires_aggr_;
clone->ClearParameter();
return clone;
}
@ -91,6 +92,9 @@ class ParamInfo {
void set_parameter(const ParameterPtr &parameter) { parameter_ = parameter; }
void ClearParameter() { parameter_ = nullptr; }
bool requires_aggr() const { return requires_aggr_; }
void set_requires_aggr(bool requires_aggr) { requires_aggr_ = requires_aggr; }
private:
std::string name_{"Parameter"};
bool requires_grad_{true};
@ -105,6 +109,7 @@ class ParamInfo {
bool cache_enable_{false};
std::vector<int64_t> cache_shape_;
ParameterPtr parameter_{nullptr};
bool requires_aggr_{true};
};
} // namespace mindspore
#endif // MINDSPORE_CORE_IR_PARAM_INFO_H_

View File

@ -1247,17 +1247,18 @@ class Cell(Cell_):
for param in params:
param.set_param_ps(init_in_server)
def set_param_fl(self, push_to_server=False, pull_from_server=False):
def set_param_fl(self, push_to_server=False, pull_from_server=False, requires_aggr=True):
"""
Set the way of parameter and server interaction.
Args:
push_to_server (bool): Whether the parameter should be pushed to server. Default: False.
pull_from_server (bool): Whether the parameter should be pulled from server. Default: False.
requires_aggr (bool): Whether the parameter should be aggregated in the server. Default: True.
"""
params = self.parameters_and_names()
for param in params:
param[1].set_param_fl(push_to_server, pull_from_server)
param[1].set_param_fl(push_to_server, pull_from_server, requires_aggr)
def set_comm_fusion(self, fusion_type, recurse=True):
"""