forked from mindspore-Ecosystem/mindspore
!21202 Fix hybrid training with pairwise encrypt
Merge pull request !21202 from ZPaC/fix-hybrid-with-pw-mask
This commit is contained in:
commit
04290b661a
|
@ -231,6 +231,9 @@ bool Executor::IsWeightAggrDone(const std::vector<std::string> &param_names) {
|
|||
std::unique_lock<std::mutex> lock(mtx);
|
||||
auto &param_aggr = param_aggrs_[name];
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(param_aggr, false);
|
||||
if (!param_aggr->requires_aggr()) {
|
||||
continue;
|
||||
}
|
||||
if (!param_aggr->IsAggregationDone()) {
|
||||
MS_LOG(DEBUG) << "Update model for " << name << " is not done yet.";
|
||||
return false;
|
||||
|
@ -265,6 +268,8 @@ std::map<std::string, AddressPtr> Executor::GetModel() {
|
|||
return model;
|
||||
}
|
||||
|
||||
const std::vector<std::string> &Executor::param_names() const { return param_names_; }
|
||||
|
||||
bool Executor::Unmask() {
|
||||
#ifdef ENABLE_ARMOUR
|
||||
auto model = GetModel();
|
||||
|
@ -274,7 +279,17 @@ bool Executor::Unmask() {
|
|||
#endif
|
||||
}
|
||||
|
||||
const std::vector<std::string> &Executor::param_names() const { return param_names_; }
|
||||
// Setter for the unmasking status flag: atomically stores the given value into unmasked_.
void Executor::set_unmasked(bool unmasked) {
  unmasked_.store(unmasked);
}
|
||||
|
||||
bool Executor::unmasked() const {
|
||||
std::string encrypt_type = ps::PSContext::instance()->encrypt_type();
|
||||
if (encrypt_type == ps::kPWEncryptType) {
|
||||
return unmasked_.load();
|
||||
} else {
|
||||
// If the algorithm of pairwise encrypt is not enabled, consider_ unmasked flag as true.
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
std::string Executor::GetTrainableParamName(const CNodePtr &cnode) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
|
|
|
@ -93,10 +93,16 @@ class Executor {
|
|||
bool initialized() const;
|
||||
|
||||
const std::vector<std::string> &param_names() const;
|
||||
|
||||
// The unmasking method for pairwise encrypt algorithm.
|
||||
bool Unmask();
|
||||
|
||||
// The setter and getter for unmasked flag to judge whether the unmasking is completed.
|
||||
void set_unmasked(bool unmasked);
|
||||
bool unmasked() const;
|
||||
|
||||
private:
|
||||
Executor() : initialized_(false), aggregation_count_(0), param_names_({}), param_aggrs_({}) {}
|
||||
Executor() : initialized_(false), aggregation_count_(0), param_names_({}), param_aggrs_({}), unmasked_(false) {}
|
||||
~Executor() = default;
|
||||
Executor(const Executor &) = delete;
|
||||
Executor &operator=(const Executor &) = delete;
|
||||
|
@ -123,9 +129,13 @@ class Executor {
|
|||
// Because ParameterAggregator is not threadsafe, we have to create mutex for each ParameterAggregator so we can
|
||||
// acquire lock before calling its method.
|
||||
std::map<std::string, std::mutex> parameter_mutex_;
|
||||
|
||||
#ifdef ENABLE_ARMOUR
|
||||
armour::CipherUnmask cipher_unmask_;
|
||||
#endif
|
||||
|
||||
// The flag represents the unmasking status.
|
||||
std::atomic<bool> unmasked_;
|
||||
};
|
||||
} // namespace server
|
||||
} // namespace fl
|
||||
|
|
|
@ -90,7 +90,7 @@ void PullWeightKernel::PullWeight(const std::shared_ptr<FBBuilder> &fbb,
|
|||
for (size_t i = 0; i < weights_names_fbs->size(); i++) {
|
||||
weight_names.push_back(weights_names_fbs->Get(i)->str());
|
||||
}
|
||||
if (!executor_->IsWeightAggrDone(weight_names)) {
|
||||
if (!executor_->IsWeightAggrDone(weight_names) || !executor_->unmasked()) {
|
||||
++retry_count_;
|
||||
std::string reason = "The aggregation for the weights is not done yet.";
|
||||
BuildPullWeightRsp(fbb, schema::ResponseCode_SucNotReady, reason, current_iter, feature_maps);
|
||||
|
|
|
@ -35,9 +35,11 @@ void ReconstructSecretsKernel::InitKernel(size_t required_cnt) {
|
|||
return;
|
||||
}
|
||||
auto last_cnt_handler = [&](std::shared_ptr<ps::core::MessageHandler>) {
|
||||
if (ps::PSContext::instance()->resetter_round() == ps::ResetterRound::kReconstructSeccrets) {
|
||||
MS_LOG(INFO) << "start FinishIteration";
|
||||
FinishIteration();
|
||||
MS_LOG(INFO) << "end FinishIteration";
|
||||
}
|
||||
return;
|
||||
};
|
||||
auto first_cnt_handler = [&](std::shared_ptr<ps::core::MessageHandler>) { return; };
|
||||
|
@ -146,6 +148,7 @@ void ReconstructSecretsKernel::OnLastCountEvent(const std::shared_ptr<ps::core::
|
|||
std::this_thread::sleep_for(std::chrono::milliseconds(5));
|
||||
}
|
||||
MS_LOG(INFO) << "end unmask";
|
||||
Executor::GetInstance().set_unmasked(true);
|
||||
std::string worker_id = std::to_string(DistributedCountService::GetInstance().local_rank());
|
||||
DistributedCountService::GetInstance().Count(name_unmask_, worker_id);
|
||||
}
|
||||
|
@ -157,6 +160,7 @@ bool ReconstructSecretsKernel::Reset() {
|
|||
DistributedCountService::GetInstance().ResetCounter(name_);
|
||||
DistributedCountService::GetInstance().ResetCounter(name_unmask_);
|
||||
StopTimer();
|
||||
Executor::GetInstance().set_unmasked(false);
|
||||
cipher_reconstruct_.ClearReconstructSecrets();
|
||||
return true;
|
||||
}
|
||||
|
|
|
@ -174,8 +174,14 @@ bool ParameterAggregator::IsOptimizingDone() const { return optimizing_done_; }
|
|||
|
||||
bool ParameterAggregator::IsPullingDone() const { return pulling_done_; }
|
||||
|
||||
// Getter for requires_aggr_, the flag recording whether this parameter needs to be aggregated.
bool ParameterAggregator::requires_aggr() const {
  return requires_aggr_;
}
|
||||
|
||||
bool ParameterAggregator::InitAggregationKernels(const CNodePtr &cnode) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (!JudgeRequiresAggr(cnode)) {
|
||||
MS_LOG(WARNING) << "Aggregation for weight for kernel " << AnfAlgo::GetCNodeName(cnode) << " is not required.";
|
||||
}
|
||||
|
||||
std::vector<std::string> aggr_kernel_names = SelectAggregationAlgorithm(cnode);
|
||||
for (const std::string &name : aggr_kernel_names) {
|
||||
auto aggr_kernel = kernel::AggregationKernelFactory::GetInstance().Create(name, cnode);
|
||||
|
@ -340,6 +346,28 @@ std::vector<std::string> ParameterAggregator::SelectAggregationAlgorithm(const C
|
|||
return aggregation_algorithm;
|
||||
}
|
||||
|
||||
bool ParameterAggregator::JudgeRequiresAggr(const CNodePtr &cnode) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
std::string cnode_name = AnfAlgo::GetCNodeName(cnode);
|
||||
if (kNameToIdxMap.count(cnode_name) == 0 || kNameToIdxMap.at(cnode_name).count("inputs") == 0 ||
|
||||
kNameToIdxMap.at(cnode_name).at("inputs").count("weight") == 0) {
|
||||
MS_LOG(EXCEPTION) << "Can't find index info of weight for kernel " << cnode_name;
|
||||
return false;
|
||||
}
|
||||
size_t cnode_weight_idx = kNameToIdxMap.at(cnode_name).at("inputs").at("weight");
|
||||
auto weight_node = AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(cnode, cnode_weight_idx), 0).first;
|
||||
MS_EXCEPTION_IF_NULL(weight_node);
|
||||
|
||||
if (!weight_node->isa<Parameter>()) {
|
||||
MS_LOG(EXCEPTION) << weight_node->fullname_with_scope() << " is not a parameter node.";
|
||||
return false;
|
||||
}
|
||||
auto param_info = weight_node->cast<ParameterPtr>()->param_info();
|
||||
MS_EXCEPTION_IF_NULL(param_info);
|
||||
requires_aggr_ = param_info->requires_aggr();
|
||||
return requires_aggr_;
|
||||
}
|
||||
|
||||
template bool ParameterAggregator::AssignMemory(std::shared_ptr<kernel::OptimizerKernel> server_kernel,
|
||||
const CNodePtr &cnode,
|
||||
const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info,
|
||||
|
|
|
@ -57,7 +57,8 @@ class ParameterAggregator {
|
|||
aggregation_done_(false),
|
||||
optimizing_done_(false),
|
||||
pulling_done_(true),
|
||||
memory_register_(nullptr) {}
|
||||
memory_register_(nullptr),
|
||||
requires_aggr_(true) {}
|
||||
~ParameterAggregator() = default;
|
||||
|
||||
// Initialize ParameterAggregator with a cnode. This cnode is normally an optimizer kernel for now.
|
||||
|
@ -94,6 +95,9 @@ class ParameterAggregator {
|
|||
bool IsOptimizingDone() const;
|
||||
bool IsPullingDone() const;
|
||||
|
||||
// Return whether this parameter requires aggregation.
|
||||
bool requires_aggr() const;
|
||||
|
||||
private:
|
||||
// Initializing aggregation/optimizer kernels based on the cnode. The reason of this is described in the file
|
||||
// kernel/kernel_factory.h.
|
||||
|
@ -118,6 +122,9 @@ class ParameterAggregator {
|
|||
// configuration, etc.
|
||||
std::vector<std::string> SelectAggregationAlgorithm(const CNodePtr &cnode);
|
||||
|
||||
// Judge whether the parameter needs to be aggregated.
|
||||
bool JudgeRequiresAggr(const CNodePtr &cnode);
|
||||
|
||||
ServerMode server_mode_;
|
||||
size_t required_push_count_;
|
||||
size_t required_pull_count_;
|
||||
|
@ -135,6 +142,9 @@ class ParameterAggregator {
|
|||
// Here stores multiple pairs of server kernels to parameters of their Launch function.
|
||||
std::vector<std::pair<std::shared_ptr<kernel::AggregationKernel>, KernelParams>> aggregation_kernel_parameters_;
|
||||
std::vector<std::pair<std::shared_ptr<kernel::OptimizerKernel>, KernelParams>> optimizer_kernel_parameters_;
|
||||
|
||||
// Whether this parameter needs to be aggregated.
|
||||
bool requires_aggr_;
|
||||
};
|
||||
} // namespace server
|
||||
} // namespace fl
|
||||
|
|
|
@ -270,6 +270,7 @@ void PSContext::GenerateResetterRound() {
|
|||
bool is_parameter_server_mode = false;
|
||||
bool is_federated_learning_mode = false;
|
||||
bool is_mixed_training_mode = false;
|
||||
bool use_pairwise_encrypt = (encrypt_type_ == kPWEncryptType);
|
||||
|
||||
if (server_mode_ == kServerModePS) {
|
||||
is_parameter_server_mode = true;
|
||||
|
@ -285,7 +286,7 @@ void PSContext::GenerateResetterRound() {
|
|||
|
||||
binary_server_context = ((unsigned int)is_parameter_server_mode << 0) |
|
||||
((unsigned int)is_federated_learning_mode << 1) |
|
||||
((unsigned int)is_mixed_training_mode << 2) | ((unsigned int)secure_aggregation_ << 3);
|
||||
((unsigned int)is_mixed_training_mode << 2) | ((unsigned int)use_pairwise_encrypt << 3);
|
||||
if (kServerContextToResetRoundMap.count(binary_server_context) == 0) {
|
||||
resetter_round_ = ResetterRound::kNoNeedToReset;
|
||||
} else {
|
||||
|
|
|
@ -44,14 +44,13 @@ constexpr char kNotEncryptType[] = "NOT_ENCRYPT";
|
|||
// 0: Server is in parameter server mode.
|
||||
// 1: Server is in federated learning mode.
|
||||
// 2: Server is in mixed training mode.
|
||||
// 3: Server enables sucure aggregation.
|
||||
// For example: 1010 stands for that the server is in federated learning mode and sucure aggregation is enabled.
|
||||
// 3: Server enables pairwise encrypt algorithm.
|
||||
// For example: 1010 stands for that the server is in federated learning mode and pairwise encrypt algorithm is enabled.
|
||||
enum class ResetterRound { kNoNeedToReset, kUpdateModel, kReconstructSeccrets, kPushWeight };
|
||||
const std::map<uint32_t, ResetterRound> kServerContextToResetRoundMap = {{0b0010, ResetterRound::kUpdateModel},
|
||||
{0b1010, ResetterRound::kReconstructSeccrets},
|
||||
{0b1100, ResetterRound::kPushWeight},
|
||||
{0b0100, ResetterRound::kPushWeight},
|
||||
{0b0100, ResetterRound::kUpdateModel}};
|
||||
{0b0100, ResetterRound::kPushWeight}};
|
||||
|
||||
class PSContext {
|
||||
public:
|
||||
|
|
|
@ -34,6 +34,7 @@ REGISTER_PYBIND_DEFINE(ParamInfo, ([](const py::module *m) {
|
|||
.def_property("comm_fusion", &ParamInfo::comm_fusion, &ParamInfo::set_comm_fusion)
|
||||
.def_property("cache_enable", &ParamInfo::cache_enable, &ParamInfo::set_cache_enable)
|
||||
.def_property("cache_shape", &ParamInfo::cache_shape, &ParamInfo::set_cache_shape)
|
||||
.def_property("requires_aggr", &ParamInfo::requires_aggr, &ParamInfo::set_requires_aggr)
|
||||
.def(py::pickle(
|
||||
[](const ParamInfo &p) { // __getstate__
|
||||
return py::make_tuple(p.name(), p.requires_grad(), p.layerwise_parallel());
|
||||
|
|
|
@ -152,6 +152,7 @@ class Parameter(Tensor_):
|
|||
self.is_param_ps = False
|
||||
self.push_weight_to_server = False
|
||||
self.pull_weight_from_server = False
|
||||
self.requires_aggr = True
|
||||
self._cast_type = None
|
||||
self._unique = False
|
||||
self.is_in_parallel = _is_in_parallel_mode()
|
||||
|
@ -236,18 +237,22 @@ class Parameter(Tensor_):
|
|||
self.init_in_server = init_in_server
|
||||
self.param_info.init_in_server = init_in_server
|
||||
|
||||
def set_param_fl(self, push_to_server=False, pull_from_server=False):
|
||||
def set_param_fl(self, push_to_server=False, pull_from_server=False, requires_aggr=True):
|
||||
"""
|
||||
Set the way of parameter and server interaction.
|
||||
|
||||
Args:
|
||||
push_to_server (bool): Whether the parameter should be pushed to server. Default: False.
|
||||
pull_from_server (bool): Whether the parameter should be pulled from server. Default: False.
|
||||
requires_aggr (bool): Whether the parameter should be aggregated in the server. Default: True.
|
||||
"""
|
||||
if push_to_server:
|
||||
self.push_weight_to_server = True
|
||||
if pull_from_server:
|
||||
self.pull_weight_from_server = True
|
||||
if not requires_aggr:
|
||||
self.requires_aggr = False
|
||||
self.param_info.requires_aggr = False
|
||||
|
||||
@property
|
||||
def inited_param(self):
|
||||
|
@ -376,6 +381,7 @@ class Parameter(Tensor_):
|
|||
x.is_param_ps = self.is_param_ps
|
||||
x.init_in_server = self.init_in_server
|
||||
x.cache_enable = self.cache_enable
|
||||
x.requires_aggr = self.requires_aggr
|
||||
if self.cache_shape:
|
||||
x.cache_shape = self.cache_shape
|
||||
if init != 'same':
|
||||
|
|
|
@ -72,6 +72,7 @@ class ParamInfo {
|
|||
this->be_cloned_ = true;
|
||||
this->be_cloned_index_.push_back(index);
|
||||
clone->init_in_server_ = this->init_in_server_;
|
||||
clone->requires_aggr_ = this->requires_aggr_;
|
||||
clone->ClearParameter();
|
||||
return clone;
|
||||
}
|
||||
|
@ -91,6 +92,9 @@ class ParamInfo {
|
|||
void set_parameter(const ParameterPtr &parameter) { parameter_ = parameter; }
|
||||
void ClearParameter() { parameter_ = nullptr; }
|
||||
|
||||
// Getter for requires_aggr_.
bool requires_aggr() const {
  return requires_aggr_;
}
|
||||
// Setter for requires_aggr_.
void set_requires_aggr(bool requires_aggr) {
  requires_aggr_ = requires_aggr;
}
|
||||
|
||||
private:
|
||||
std::string name_{"Parameter"};
|
||||
bool requires_grad_{true};
|
||||
|
@ -105,6 +109,7 @@ class ParamInfo {
|
|||
bool cache_enable_{false};
|
||||
std::vector<int64_t> cache_shape_;
|
||||
ParameterPtr parameter_{nullptr};
|
||||
bool requires_aggr_{true};
|
||||
};
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CORE_IR_PARAM_INFO_H_
|
||||
|
|
|
@ -1247,17 +1247,18 @@ class Cell(Cell_):
|
|||
for param in params:
|
||||
param.set_param_ps(init_in_server)
|
||||
|
||||
def set_param_fl(self, push_to_server=False, pull_from_server=False):
|
||||
def set_param_fl(self, push_to_server=False, pull_from_server=False, requires_aggr=True):
|
||||
"""
|
||||
Set the way of parameter and server interaction.
|
||||
|
||||
Args:
|
||||
push_to_server (bool): Whether the parameter should be pushed to server. Default: False.
|
||||
pull_from_server (bool): Whether the parameter should be pulled from server. Default: False.
|
||||
requires_aggr (bool): Whether the parameter should be aggregated in the server. Default: True.
|
||||
"""
|
||||
params = self.parameters_and_names()
|
||||
for param in params:
|
||||
param[1].set_param_fl(push_to_server, pull_from_server)
|
||||
param[1].set_param_fl(push_to_server, pull_from_server, requires_aggr)
|
||||
|
||||
def set_comm_fusion(self, fusion_type, recurse=True):
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue