fl chinese doc

This commit is contained in:
w00517672 2022-05-26 17:08:00 +08:00
parent a29344460c
commit b47fd3fd91
4 changed files with 382 additions and 19 deletions

View File

@ -0,0 +1,42 @@
云云联邦学习
================================
.. py:class:: mindspore.FederatedLearningManager(model, sync_frequency, sync_type="fixed", **kwargs)
在训练过程中管理联邦学习。
**参数:**
- **model** (nn.Cell) - 一个用于联邦训练的模型。
- **sync_frequency** (int) - 联邦学习中的参数同步频率。
需要注意在数据下沉模式中频率的单位是epoch的数量。否则频率的单位是step的数量。
在自适应同步频率模式下为初始同步频率,在固定频率模式下为同步频率。
- **sync_type** (str) - 采用同步策略类型的参数。
支持["fixed", "adaptive"]。默认值:"fixed"。
- fixed参数的同步频率是固定的。
- adaptive参数的同步频率是自适应变化的。
- **min_consistent_rate** (float) - 最小一致性比率阈值,该值越大同步频率提升难度越大。
取值范围大于等于0.0。默认值1.1。
- **min_consistent_rate_at_round** (int) - 最小一致性比率阈值的轮数,该值越大同步频率提升难度越大。
取值范围大于等于0。默认值0。
- **ema_alpha** (float) - 梯度一致性平滑系数,该值越小越会根据当前轮次的梯度分叉情况来判断频率是否
需要改变,反之则会更加根据历史梯度分叉情况来判断。
取值范围:(0.0, 1.0)。默认值0.5。
- **observation_window_size** (int) - 观察时间窗的轮数,该值越大同步频率减小难度越大。
取值范围大于0。默认值5。
- **frequency_increase_ratio** (int) - 频率提升幅度,该值越大频率提升幅度越大。
取值范围大于0。默认值2。
- **unchanged_round** (int) - 频率不发生变化的轮数在前unchanged_round个轮次频率不会发生变化。
取值范围大于等于0。默认值0。
.. note::
这是一个实验原型,可能会有变化。
.. py:method:: step_end(run_context)
在step结束时同步参数。如果 `sync_type` 是"adaptive",同步频率会在这里自适应的调整。
**参数:**
- **run_context** (RunContext) - 包含模型的相关信息。

View File

@ -0,0 +1,190 @@
Federated-Server
======================
.. py:function:: mindspore.set_fl_context(**kwargs)
设置联邦学习训练模式的context。
.. note::
设置属性时,必须输入属性名称。
某些配置需要在特定角色上设置,有关详细信息,请参见下表:
+-------------------------+---------------------------------+-----------------------------+
| 功能分类 | 配置参数 | 联邦学习角色 |
+=========================+=================================+=============================+
| 组网配置 | enable_fl | scheduler/server/worker |
| +---------------------------------+-----------------------------+
| | server_mode | scheduler/server/worker |
| +---------------------------------+-----------------------------+
| | ms_role | scheduler/server/worker |
| +---------------------------------+-----------------------------+
| | worker_num | scheduler/server/worker |
| +---------------------------------+-----------------------------+
| | server_num | scheduler/server/worker |
| +---------------------------------+-----------------------------+
| | scheduler_ip | scheduler/server/worker |
| +---------------------------------+-----------------------------+
| | scheduler_port | scheduler/server/worker |
| +---------------------------------+-----------------------------+
| | fl_server_port | scheduler/server |
| +---------------------------------+-----------------------------+
| | scheduler_manage_port | scheduler |
+-------------------------+---------------------------------+-----------------------------+
| 训练配置 | start_fl_job_threshold | server |
| +---------------------------------+-----------------------------+
| | start_fl_job_time_window | server |
| +---------------------------------+-----------------------------+
| | update_model_ratio | server |
| +---------------------------------+-----------------------------+
| | update_model_time_window | server |
| +---------------------------------+-----------------------------+
| | fl_iteration_num | server |
| +---------------------------------+-----------------------------+
| | client_epoch_num | server |
| +---------------------------------+-----------------------------+
| | client_batch_size | server |
| +---------------------------------+-----------------------------+
| | client_learning_rate | server |
| +---------------------------------+-----------------------------+
| | worker_step_num_per_iteration | worker |
| +---------------------------------+-----------------------------+
| | global_iteration_time_window | server |
+-------------------------+---------------------------------+-----------------------------+
| 加密配置 | encrypt_type | server |
| +---------------------------------+-----------------------------+
| | share_secrets_ratio | server |
| +---------------------------------+-----------------------------+
| | cipher_time_window | server |
| +---------------------------------+-----------------------------+
| | reconstruct_secrets_threshold | server |
| +---------------------------------+-----------------------------+
| | dp_eps | server |
| +---------------------------------+-----------------------------+
| | dp_delta | server |
| +---------------------------------+-----------------------------+
| | dp_norm_clip | server |
| +---------------------------------+-----------------------------+
| | sign_k | server |
| +---------------------------------+-----------------------------+
| | sign_eps | server |
| +---------------------------------+-----------------------------+
| | sign_thr_ratio | server |
| +---------------------------------+-----------------------------+
| | sign_global_lr | server |
| +---------------------------------+-----------------------------+
| | sign_dim_out | server |
+-------------------------+---------------------------------+-----------------------------+
| 认证配置 | enable_ssl | server |
| +---------------------------------+-----------------------------+
| | client_password | server |
| +---------------------------------+-----------------------------+
| | server_password | server |
| +---------------------------------+-----------------------------+
| | root_first_ca_path | server |
| +---------------------------------+-----------------------------+
| | pki_verify | server |
| +---------------------------------+-----------------------------+
| | root_second_ca_path | server |
| +---------------------------------+-----------------------------+
| | equip_crl_path | server |
| +---------------------------------+-----------------------------+
| | replay_attack_time_diff | server |
| +---------------------------------+-----------------------------+
| | http_url_prefix | server |
+-------------------------+---------------------------------+-----------------------------+
| 容灾和运维配置 | fl_name | server |
| +---------------------------------+-----------------------------+
| | config_file_path | scheduler/server |
| +---------------------------------+-----------------------------+
| | checkpoint_dir | server |
+-------------------------+---------------------------------+-----------------------------+
| 压缩配置 | upload_compress_type | server |
| +---------------------------------+-----------------------------+
| | upload_sparse_rate | server |
| +---------------------------------+-----------------------------+
| | download_compress_type | server |
+-------------------------+---------------------------------+-----------------------------+
**参数:**
- **enable_fl** (bool) - 是否启用联邦学习训练模式。默认值False。
- **server_mode** (str) - 描述服务器模式,它必须是'FEDERATED_LEARNING'和'HYBRID_TRAINING'中的一个。
- **ms_role** (str) - 进程在联邦学习模式中的角色,
它必须是'MS_SERVER'、'MS_WORKER'和'MS_SCHED'中的一个。
- **worker_num** (int) - 云侧训练进程的数量。
- **server_num** (int) - 联邦学习服务器的数量。默认值0。
- **scheduler_ip** (str) - 调度器IP。默认值'0.0.0.0'。
- **scheduler_port** (int) - 调度器端口。默认值6667。
- **fl_server_port** (int) - 服务器端口。默认值6668。
- **start_fl_job_threshold** (int) - 开启联邦学习作业的阈值计数。默认值1。
- **start_fl_job_time_window** (int) - 开启联邦学习作业的时间窗口持续时间以毫秒为单位。默认值300000。
- **update_model_ratio** (float) - 计算更新模型阈值计数的比率。默认值1.0。
- **update_model_time_window** (int) - 更新模型的时间窗口持续时间以毫秒为单位。默认值300000。
- **fl_name** (str) - 联邦学习作业名称。默认值:""。
- **fl_iteration_num** (int) - 联邦学习的迭代次数即客户端和服务器的交互次数。默认值20。
- **client_epoch_num** (int) - 客户端训练epoch数量。默认值25。
- **client_batch_size** (int) - 客户端训练数据batch数。默认值32。
- **client_learning_rate** (float) - 客户端训练学习率。默认值0.001。
- **worker_step_num_per_iteration** (int) - 端云联邦中云侧训练进程在与服务器通信之前的独立训练步数。默认值65。
- **encrypt_type** (str) - 用于联邦学习的安全策略,可以是'NOT_ENCRYPT'、'DP_ENCRYPT'、
'PW_ENCRYPT'、'STABLE_PW_ENCRYPT'或'SIGNDS'。如果是'DP_ENCRYPT',则将对客户端应用差分隐私模式,
隐私保护效果将由上面所述的dp_eps、dp_delta、dp_norm_clip确定。如果'PW_ENCRYPT'则将应用成对pairwisePW安全聚合
来保护客户端模型在跨设备场景中不被窃取。如果'STABLE_PW_ENCRYPT',则将应用成对安全聚合来保护客户端模型在云云联邦场景中
免受窃取。如果'SIGNDS'则将在于客户端上使用SignDS策略。SignDS的介绍可以参照
`SignDS-FL: Local Differentially Private Federated Learning with Sign-based Dimension Selection <https://dl.acm.org/doi/abs/10.1145/3517820>`_
默认值:'NOT_ENCRYPT'。
- **share_secrets_ratio** (float) - PW参与秘密分享的客户端比例。默认值1.0。
- **cipher_time_window** (int) - PW每个加密轮次的时间窗口持续时间以毫秒为单位。默认值300000。
- **reconstruct_secrets_threshold** (int) - PW秘密重建的阈值。默认值2000。
- **dp_eps** (float) - DP差分隐私机制的epsilon预算。dp_eps越小隐私保护效果越好。默认值50.0。
- **dp_delta** (float) - DP差分隐私机制的delta预算通常等于客户端数量的倒数。dp_delta越小隐私保护效果越好。默认值0.01。
- **dp_norm_clip** (float) - DP差分隐私梯度裁剪的控制因子。建议其值为0.5~2。默认值1.0。
- **sign_k** (float) - SignDSTop-k比率即Top-k维度的数量除以维度总数。建议取值范围在(0, 0.25]内。默认值0.01。
- **sign_eps** (float) - SignDS隐私预算。该值越小隐私保护力度越大精度越低。建议取值范围在(0, 100]内。默认值100。
- **sign_thr_ratio** (float) - SignDS预期Top-k维度的阈值。建议取值范围在[0.5, 1]内。默认值0.6。
- **sign_global_lr** (float) - SignDS分配给选定维的常量值。适度增大该值会提高收敛速度但有可能让模型梯度爆炸。取值必须大于0。默认值1。
- **sign_dim_out** (int) - SignDS输出维度的数量。建议取值范围在[0, 50]内。默认值0。
- **config_file_path** (str) - 用于集群容灾恢复的配置文件路径、认证相关参数以及文件路径、评价指标文件路径和运维相关文件路径。默认值:""。
- **scheduler_manage_port** (int) - 用于扩容/缩容的调度器管理端口。默认值11202。
- **enable_ssl** (bool) - 设置联邦学习开启SSL安全通信。默认值False。
- **client_password** (str) - 解密客户端证书中存储的秘钥的密码。默认值:""。
- **server_password** (str) - 解密服务器证书中存储的秘钥的密码。默认值:""。
- **pki_verify** (bool) - 如果为True则将打开服务器和客户端之间的身份验证。
您还应从https://pki.consumer.huawei.com/ca/下载Root CA证书、Root CA G2证书和移动设备CRL证书。
需要注意的是只有当客户端是具有HUKS服务的Android环境时pki_verify可以为True。默认值False。
- **root_first_ca_path** (str) - Root CA证书的文件路径。当pki_verify为True时需要设置该值。默认值""。
- **root_second_ca_path** (str) - Root CA G2证书的文件路径。当pki_verify为True时需要设置该值。默认值""。
- **equip_crl_path** (str) - 移动设备CRL证书的文件路径。当pki_verify为True时需要设置该值。默认值""。
- **replay_attack_time_diff** (int) - 证书时间戳验证的最大可容忍错误毫秒。默认值600000。
- **http_url_prefix** (str) - 设置联邦学习端云通信的http路径。默认值""。
- **global_iteration_time_window** (int) - 一次迭代的全局时间窗口轮次ms。默认值3600000。
- **checkpoint_dir** (str) - server读取和保存模型文件的目录。若没有设置则不读取和保存模型文件。默认值""。
- **upload_compress_type** (str) - 上传压缩方法。可以是'NO_COMPRESS'或'DIFF_SPARSE_QUANT'。如果是'NO_COMPRESS',则不对上传的模型
进行压缩。如果是'DIFF_SPARSE_QUANT',则对上传的模型使用权重差+稀疏+量化压缩策略。默认值:'NO_COMPRESS'。
- **upload_sparse_rate** (float) - 上传压缩稀疏率。稀疏率越大,则压缩率越小。取值范围:(0, 1.0]。默认值0.4。
- **download_compress_type** (str) - 下载压缩方法。可以是'NO_COMPRESS'或'QUANT'。如果是'NO_COMPRESS',则不对下载的模型进行压缩。
如果是'QUANT',则对下载的模型使用量化压缩策略。默认值:'NO_COMPRESS'。
**异常:**
**ValueError** - 如果输入key不是联邦学习模式context中的属性。
.. py:function:: mindspore.get_fl_context(attr_key)
根据key获取联邦学习模式context中的属性值。
**参数:**
- **attr_key** (str) - 属性的key。
请参考 `set_fl_context` 中的参数来决定应传递的key。
**返回:**
根据key返回属性值。
**异常:**
**ValueError** - 如果输入key不是联邦学习模式context中的属性。

View File

@ -783,7 +783,7 @@ def set_context(**kwargs):
If enable_graph_kernel is set to True, acceleration can be enabled.
For details of graph kernel fusion, please check
`Enabling Graph Kernel Fusion
<https://www.mindspore.cn/docs/en/master/design/graph_fusion_engine.html>`_.
<https://www.mindspore.cn/tutorials/experts/en/master/debug/graph_fusion_engine.html>`_.
graph_kernel_flags (str)
Optimization options of graph kernel fusion, and the priority is higher when it conflicts
with enable_graph_kernel. Only for experienced users.
@ -1050,29 +1050,126 @@ def set_fl_context(**kwargs):
"""
Set federated learning training mode context.
Note:
Attribute name is required for setting attributes.
Some configurations need to be set on specific roles. For details, see the following table:
+-------------------------+---------------------------------+-----------------------------+
| Function Classification | Configuration Parameters | Federated Role |
+=========================+=================================+=============================+
| Network Configuration | enable_fl | scheduler/server/worker |
| +---------------------------------+-----------------------------+
| | server_mode | scheduler/server/worker |
| +---------------------------------+-----------------------------+
| | ms_role | scheduler/server/worker |
| +---------------------------------+-----------------------------+
| | worker_num | scheduler/server/worker |
| +---------------------------------+-----------------------------+
| | server_num | scheduler/server/worker |
| +---------------------------------+-----------------------------+
| | scheduler_ip | scheduler/server/worker |
| +---------------------------------+-----------------------------+
| | scheduler_port | scheduler/server/worker |
| +---------------------------------+-----------------------------+
| | fl_server_port | scheduler/server |
| +---------------------------------+-----------------------------+
| | scheduler_manage_port | scheduler |
+-------------------------+---------------------------------+-----------------------------+
| Training Configuration | start_fl_job_threshold | server |
| +---------------------------------+-----------------------------+
| | start_fl_job_time_window | server |
| +---------------------------------+-----------------------------+
| | update_model_ratio | server |
| +---------------------------------+-----------------------------+
| | update_model_time_window | server |
| +---------------------------------+-----------------------------+
| | fl_iteration_num | server |
| +---------------------------------+-----------------------------+
| | client_epoch_num | server |
| +---------------------------------+-----------------------------+
| | client_batch_size | server |
| +---------------------------------+-----------------------------+
| | client_learning_rate | server |
| +---------------------------------+-----------------------------+
| | worker_step_num_per_iteration | worker |
| +---------------------------------+-----------------------------+
| | global_iteration_time_window | server |
+-------------------------+---------------------------------+-----------------------------+
| Encrypt Configuration | encrypt_type | server |
| +---------------------------------+-----------------------------+
| | share_secrets_ratio | server |
| +---------------------------------+-----------------------------+
| | cipher_time_window | server |
| +---------------------------------+-----------------------------+
| | reconstruct_secrets_threshold | server |
| +---------------------------------+-----------------------------+
| | dp_eps | server |
| +---------------------------------+-----------------------------+
| | dp_delta | server |
| +---------------------------------+-----------------------------+
| | dp_norm_clip | server |
| +---------------------------------+-----------------------------+
| | sign_k | server |
| +---------------------------------+-----------------------------+
| | sign_eps | server |
| +---------------------------------+-----------------------------+
| | sign_thr_ratio | server |
| +---------------------------------+-----------------------------+
| | sign_global_lr | server |
| +---------------------------------+-----------------------------+
| | sign_dim_out | server |
+-------------------------+---------------------------------+-----------------------------+
| Cert Configuration | enable_ssl | server |
| +---------------------------------+-----------------------------+
| | client_password | server |
| +---------------------------------+-----------------------------+
| | server_password | server |
| +---------------------------------+-----------------------------+
| | root_first_ca_path | server |
| +---------------------------------+-----------------------------+
| | pki_verify | server |
| +---------------------------------+-----------------------------+
| | root_second_ca_path | server |
| +---------------------------------+-----------------------------+
| | equip_crl_path | server |
| +---------------------------------+-----------------------------+
| | replay_attack_time_diff | server |
| +---------------------------------+-----------------------------+
| | http_url_prefix | server |
+-------------------------+---------------------------------+-----------------------------+
| DR and OM Configuration | fl_name | server |
| +---------------------------------+-----------------------------+
| | config_file_path | scheduler/server |
| +---------------------------------+-----------------------------+
| | checkpoint_dir | server |
+-------------------------+---------------------------------+-----------------------------+
| Compress Configuration | upload_compress_type | server |
| +---------------------------------+-----------------------------+
| | upload_sparse_rate | server |
| +---------------------------------+-----------------------------+
| | download_compress_type | server |
+-------------------------+---------------------------------+-----------------------------+
Args:
enable_fl (bool): Whether to enable federated learning training mode.
Default: False.
server_mode (str): Describe the server mode, which must one of 'FEDERATED_LEARNING' and 'HYBRID_TRAINING'.
Default: 'FEDERATED_LEARNING'.
ms_role (str): The process's role in the federated learning mode,
which must be one of 'MS_SERVER', 'MS_WORKER' and 'MS_SCHED'.
Default: 'MS_SERVER'.
worker_num (int): The number of workers. For current version, this must be set to 1 or 0.
worker_num (int): The number of workers.
server_num (int): The number of federated learning servers. Default: 0.
scheduler_ip (str): The scheduler IP. Default: '0.0.0.0'.
scheduler_port (int): The scheduler port. Default: 6667.
fl_server_port (int): The http port of the federated learning server.
Normally for each server this should be set to the same value. Default: 6668.
enable_fl_client (bool): Whether this process is federated learning client. Default: False.
fl_server_port (int): The server port. Default: 6668.
start_fl_job_threshold (int): The threshold count of startFLJob. Default: 1.
start_fl_job_time_window (int): The time window duration for startFLJob in millisecond. Default: 3000.
start_fl_job_time_window (int): The time window duration for startFLJob in millisecond. Default: 300000.
share_secrets_ratio (float): The ratio for computing the threshold count of share secrets. Default: 1.0.
update_model_ratio (float): The ratio for computing the threshold count of updateModel. Default: 1.0.
cipher_time_window (int): The time window duration for each cipher round in millisecond. Default: 300000.
reconstruct_secrets_threshold (int): The threshold count of reconstruct threshold. Default: 0.
update_model_time_window (int): The time window duration for updateModel in millisecond. Default: 3000.
fl_name (string): The federated learning job name. Default: ''.
reconstruct_secrets_threshold (int): The threshold count of reconstruct threshold. Default: 2000.
update_model_time_window (int): The time window duration for updateModel in millisecond. Default: 300000.
fl_name (str): The federated learning job name. Default: ''.
fl_iteration_num (int): Iteration number of federated learning,
which is the number of interactions between client and server. Default: 20.
client_epoch_num (int): Client training epoch number. Default: 25.
@ -1086,7 +1183,7 @@ def set_fl_context(**kwargs):
client number. The smaller the dp_delta, the better the privacy protection effect. Default: 0.01.
dp_norm_clip (float): A factor used for clipping model's weights for differential mechanism. Its value is
suggested to be 0.5~2. Default: 1.0.
encrypt_type (string): Secure schema for federated learning, which can be 'NOT_ENCRYPT', 'DP_ENCRYPT',
encrypt_type (str): Secure schema for federated learning, which can be 'NOT_ENCRYPT', 'DP_ENCRYPT',
'PW_ENCRYPT', 'STABLE_PW_ENCRYPT' or 'SIGNDS'. If 'DP_ENCRYPT', differential privacy schema would be applied
for clients and the privacy protection effect would be determined by dp_eps, dp_delta and dp_norm_clip
as described above. If 'PW_ENCRYPT', pairwise secure aggregation would be applied to protect clients'
@ -1099,9 +1196,10 @@ def set_fl_context(**kwargs):
sign_thr_ratio (float): SignDS: Threshold of the expected topk dimension. Default: 0.6.
sign_global_lr (float): SignDS: The constant value assigned to the selected dimension. Default: 1.
sign_dim_out (int): SignDS: Number of output dimensions. Default: 0.
config_file_path (string): Configuration file path used by recovery. Default: ''.
config_file_path (str): Configuration file paths used for cluster disaster recovery, cert-related parameters
and file paths, metrics-related file path and O&M-related file paths. Default: ''.
scheduler_manage_port (int): scheduler manage port used to scale out/in. Default: 11202.
enable_ssl (bool): Set PS SSL mode enabled or disabled. Default: False.
enable_ssl (bool): Enable SSL secure communication for federated learning. Default: False.
client_password (str): Password to decrypt the secret key stored in the client certificate. Default: ''.
server_password (str): Password to decrypt the secret key stored in the server certificate. Default: ''.
pki_verify (bool): If True, the identity verification between server and clients would be turned on.
@ -1116,12 +1214,21 @@ def set_fl_context(**kwargs):
pki_verify is True. Default: "".
replay_attack_time_diff (int): The maximum tolerable error of certificate timestamp verification (ms).
Default: 600000.
http_url_prefix (string): The http url prefix for http server.
http_url_prefix (str): Setting the http path for federated learning cross-device communication.
Default: "".
global_iteration_time_window (unsigned long): The global iteration time window for one iteration
global_iteration_time_window (int): The global iteration time window for one iteration
with rounds(ms). Default: 3600000.
checkpoint_dir (string): The Server model checkpoint directory. If no checkpoint dir is set,
the startup script directory is used by default. Default: "".
checkpoint_dir (str): The directory where the server reads and saves model files. If it is not set,
the model file will not be read and saved. Default: "".
upload_compress_type (str): Upload compression method, which can be 'NO_COMPRESS' or 'DIFF_SPARSE_QUANT'.
If 'NO_COMPRESS', our framework does not compress the uploaded model. If 'DIFF_SPARSE_QUANT', the
weight difference + sparse + quantization compression strategy is used for the uploaded model.
Default: 'NO_COMPRESS'.
upload_sparse_rate (float): Upload compression sparsity ratio. The larger the sparsity ratio, the smaller the
compression ratio. Value range: (0, 1.0]. Default: 0.4.
download_compress_type (str): Download the compression method, which can be 'NO_COMPRESS' or 'QUANT'. If
'NO_COMPRESS', our framework does not compress the downloaded model. If'QUANT', the quantized compression
strategy is used for the downloaded model. Default: 'NO_COMPRESS'.
Raises:
ValueError: If input key is not the attribute in federated learning mode context.

View File

@ -94,6 +94,27 @@ class FederatedLearningManager(Callback):
- fixed: The frequency of parameter synchronization is fixed.
- adaptive: The frequency of parameter synchronization changes adaptively.
min_consistent_rate (float): Minimum consistency ratio threshold. The greater the value, the more
difficult it is to improve the synchronization frequency.
Value range: greater than or equal to 0.0. Default: 1.1.
min_consistent_rate_at_round (int): The number of rounds of the minimum consistency ratio threshold.
The greater the value, the more difficult it is to improve the
synchronization frequency.
Value range: greater than or equal to 0. Default: 0.
ema_alpha (float): Gradient consistency smoothing coefficient. The smaller the value, the more the
frequency will be judged according to the gradient bifurcation of the current round
more. Otherwise it will be judged according to the historical gradient bifurcation
more.
Value range: (0.0, 1.0). Default: 0.5.
observation_window_size (int): The number of rounds in the observation time window. The greater the
value, the more difficult it is to reduce the synchronization frequency.
Value range: greater than 0. Default: 5.
frequency_increase_ratio (int): Frequency increase amplitude. The greater the value, the greater the
frequency increase amplitude.
Value range: greater than 0. Default: 2.
unchanged_round (int): The number of rounds whose frequency does not change. The frequency is unchanged
before unchanged_round rounds.
Value range: greater than or equal to 0. Default: 0.
Note:
This is an experimental prototype that is subject to change.
@ -182,7 +203,10 @@ class FederatedLearningManager(Callback):
abs_grads[self._as_prefix + param.name] = np.abs(param.asnumpy() - self._last_param[param.name])
for param in self._model.trainable_params():
if self._as_prefix in param.name:
param.set_data(Parameter(abs_grads[param.name]))
try:
param.set_data(Parameter(abs_grads[param.name]))
except KeyError:
print("{} is not in abs_grads".format(param.name))
def _as_analyze_gradient(self):
"""