forked from mindspore-Ecosystem/mindspore
!16912 Optimize ps doc
From: @zpac Reviewed-by: @limingqi107,@cristoval Signed-off-by: @cristoval
This commit is contained in:
commit
c3f2257793
|
@ -359,8 +359,6 @@ PYBIND11_MODULE(_c_expression, m) {
|
|||
.def("set_client_batch_size", &PSContext::set_client_batch_size, "Set federated learning client batch size.")
|
||||
.def("set_client_learning_rate", &PSContext::set_client_learning_rate,
|
||||
"Set federated learning client learning rate.")
|
||||
.def("set_secure_aggregation", &PSContext::set_secure_aggregation,
|
||||
"Set federated learning client using secure aggregation.")
|
||||
// Bind the Python-facing setter to PSContext::set_enable_ssl(bool). The enable_ssl() getter
// takes no argument, so binding it under this name would reject the value passed from Python.
.def("set_enable_ssl", &PSContext::set_enable_ssl, "Set PS SSL mode enabled or disabled.");
|
||||
|
||||
(void)py::class_<OpInfoLoaderPy, std::shared_ptr<OpInfoLoaderPy>>(m, "OpInfoLoaderPy")
|
||||
|
|
|
@ -304,10 +304,6 @@ void PSContext::set_client_learning_rate(float client_learning_rate) { client_le
|
|||
|
||||
// Returns the current value of client_learning_rate_.
float PSContext::client_learning_rate() const {
  return client_learning_rate_;
}
|
||||
|
||||
// Stores the given flag in secure_aggregation_.
void PSContext::set_secure_aggregation(bool secure_aggregation) {
  secure_aggregation_ = secure_aggregation;
}
|
||||
|
||||
// Returns the current value of secure_aggregation_.
bool PSContext::secure_aggregation() const {
  return secure_aggregation_;
}
|
||||
|
||||
// Returns the current value of enable_ssl_.
bool PSContext::enable_ssl() const {
  return enable_ssl_;
}
|
||||
|
||||
// Stores the given flag in enable_ssl_.
void PSContext::set_enable_ssl(bool enabled) {
  enable_ssl_ = enabled;
}
|
||||
|
|
|
@ -144,10 +144,6 @@ class PSContext {
|
|||
void set_client_learning_rate(float client_learning_rate);
|
||||
float client_learning_rate() const;
|
||||
|
||||
// Set true if using secure aggregation for federated learning.
|
||||
void set_secure_aggregation(bool secure_aggregation);
|
||||
bool secure_aggregation() const;
|
||||
|
||||
core::ClusterConfig &cluster_config();
|
||||
|
||||
private:
|
||||
|
|
|
@ -740,15 +740,14 @@ def set_ps_context(**kwargs):
|
|||
Some other environment variables should also be set for parameter server training mode.
|
||||
These environment variables are listed below:
|
||||
|
||||
MS_SERVER_NUM # Server number
|
||||
MS_WORKER_NUM # Worker number
|
||||
MS_SCHED_HOST # Scheduler IP address
|
||||
MS_SCHED_PORT # Scheduler port
|
||||
MS_ROLE # The role of this process:
|
||||
MS_SCHED #represents the scheduler,
|
||||
MS_WORKER #represents the worker,
|
||||
MS_PSERVER #represents the Server
|
||||
|
||||
- MS_SERVER_NUM: Server number
|
||||
- MS_WORKER_NUM: Worker number
|
||||
- MS_SCHED_HOST: Scheduler IP address
|
||||
- MS_SCHED_PORT: Scheduler port
|
||||
- MS_ROLE: The role of this process:
|
||||
- MS_SCHED: represents the scheduler,
|
||||
- MS_WORKER: represents the worker,
|
||||
- MS_PSERVER: represents the Server
|
||||
|
||||
Args:
|
||||
enable_ps (bool): Whether to enable parameter server training mode.
|
||||
|
@ -771,6 +770,8 @@ def get_ps_context(attr_key):
|
|||
Args:
|
||||
attr_key (str): The key of the attribute.
|
||||
|
||||
- "enable_ps": Whether to enable parameter server training mode.
|
||||
|
||||
Returns:
|
||||
Returns attribute value according to the key.
|
||||
|
||||
|
@ -818,7 +819,6 @@ def set_fl_context(**kwargs):
|
|||
client_epoch_num (int): Client training epoch number. Default: 25.
|
||||
client_batch_size (int): Client training data batch size. Default: 32.
|
||||
client_learning_rate (float): Client training learning rate. Default: 0.001.
|
||||
secure_aggregation (bool): Whether to use secure aggregation algorithm. Default: False.
|
||||
|
||||
Raises:
|
||||
ValueError: If the input key is not an attribute of the federated learning mode context.
|
||||
|
@ -835,11 +835,15 @@ def get_fl_context(attr_key):
|
|||
|
||||
Args:
|
||||
attr_key (str): The key of the attribute.
|
||||
Please refer to `set_fl_context`'s parameters to decide which key should be passed.
|
||||
|
||||
Returns:
|
||||
Returns attribute value according to the key.
|
||||
|
||||
Raises:
|
||||
ValueError: If the input key is not an attribute of the federated learning mode context.
|
||||
|
||||
Examples:
|
||||
>>> context.get_fl_context("server_mode")
|
||||
"""
|
||||
return _get_ps_context(attr_key)
|
||||
|
|
|
@ -52,7 +52,6 @@ _set_ps_context_func_map = {
|
|||
"client_epoch_num": ps_context().set_client_epoch_num,
|
||||
"client_batch_size": ps_context().set_client_batch_size,
|
||||
"client_learning_rate": ps_context().set_client_learning_rate,
|
||||
"secure_aggregation": ps_context().set_secure_aggregation,
|
||||
"enable_ps_ssl": ps_context().set_enable_ssl
|
||||
}
|
||||
|
||||
|
|
|
@ -13,7 +13,6 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import ast
|
||||
import argparse
|
||||
import subprocess
|
||||
|
||||
|
@ -34,7 +33,6 @@ parser.add_argument("--fl_iteration_num", type=int, default=25)
|
|||
parser.add_argument("--client_epoch_num", type=int, default=20)
|
||||
parser.add_argument("--client_batch_size", type=int, default=32)
|
||||
parser.add_argument("--client_learning_rate", type=float, default=0.1)
|
||||
parser.add_argument("--secure_aggregation", type=ast.literal_eval, default=False)
|
||||
parser.add_argument("--local_server_num", type=int, default=-1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -55,7 +53,6 @@ if __name__ == "__main__":
|
|||
client_epoch_num = args.client_epoch_num
|
||||
client_batch_size = args.client_batch_size
|
||||
client_learning_rate = args.client_learning_rate
|
||||
secure_aggregation = args.secure_aggregation
|
||||
local_server_num = args.local_server_num
|
||||
|
||||
if local_server_num == -1:
|
||||
|
@ -86,7 +83,6 @@ if __name__ == "__main__":
|
|||
cmd_server += " --client_epoch_num=" + str(client_epoch_num)
|
||||
cmd_server += " --client_batch_size=" + str(client_batch_size)
|
||||
cmd_server += " --client_learning_rate=" + str(client_learning_rate)
|
||||
cmd_server += " --secure_aggregation=" + str(secure_aggregation)
|
||||
cmd_server += " > server.log 2>&1 &"
|
||||
|
||||
import time
|
||||
|
|
|
@ -13,7 +13,6 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import ast
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
|
@ -42,7 +41,6 @@ parser.add_argument("--fl_iteration_num", type=int, default=25)
|
|||
parser.add_argument("--client_epoch_num", type=int, default=20)
|
||||
parser.add_argument("--client_batch_size", type=int, default=32)
|
||||
parser.add_argument("--client_learning_rate", type=float, default=0.1)
|
||||
parser.add_argument("--secure_aggregation", type=ast.literal_eval, default=False)
|
||||
|
||||
args, _ = parser.parse_known_args()
|
||||
device_target = args.device_target
|
||||
|
@ -62,7 +60,6 @@ fl_iteration_num = args.fl_iteration_num
|
|||
client_epoch_num = args.client_epoch_num
|
||||
client_batch_size = args.client_batch_size
|
||||
client_learning_rate = args.client_learning_rate
|
||||
secure_aggregation = args.secure_aggregation
|
||||
|
||||
ctx = {
|
||||
"enable_fl": True,
|
||||
|
@ -82,7 +79,6 @@ ctx = {
|
|||
"client_epoch_num": client_epoch_num,
|
||||
"client_batch_size": client_batch_size,
|
||||
"client_learning_rate": client_learning_rate,
|
||||
"secure_aggregation": secure_aggregation
|
||||
}
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=device_target, save_graphs=False)
|
||||
|
|
Loading…
Reference in New Issue