!16912 Optimize ps doc

From: @zpac
Reviewed-by: @limingqi107,@cristoval
Signed-off-by: @cristoval
This commit is contained in:
mindspore-ci-bot 2021-05-26 09:21:06 +08:00 committed by Gitee
commit c3f2257793
7 changed files with 14 additions and 29 deletions

View File

@ -359,8 +359,6 @@ PYBIND11_MODULE(_c_expression, m) {
.def("set_client_batch_size", &PSContext::set_client_batch_size, "Set federated learning client batch size.")
.def("set_client_learning_rate", &PSContext::set_client_learning_rate,
"Set federated learning client learning rate.")
.def("set_secure_aggregation", &PSContext::set_secure_aggregation,
"Set federated learning client using secure aggregation.")
.def("set_enable_ssl", &PSContext::enable_ssl, "Set PS SSL mode enabled or disabled.");
(void)py::class_<OpInfoLoaderPy, std::shared_ptr<OpInfoLoaderPy>>(m, "OpInfoLoaderPy")

View File

@ -304,10 +304,6 @@ void PSContext::set_client_learning_rate(float client_learning_rate) { client_le
float PSContext::client_learning_rate() const { return client_learning_rate_; }
void PSContext::set_secure_aggregation(bool secure_aggregation) { secure_aggregation_ = secure_aggregation; }
bool PSContext::secure_aggregation() const { return secure_aggregation_; }
bool PSContext::enable_ssl() const { return enable_ssl_; }
void PSContext::set_enable_ssl(bool enabled) { enable_ssl_ = enabled; }

View File

@ -144,10 +144,6 @@ class PSContext {
void set_client_learning_rate(float client_learning_rate);
float client_learning_rate() const;
// Set true if using secure aggregation for federated learning.
void set_secure_aggregation(bool secure_aggregation);
bool secure_aggregation() const;
core::ClusterConfig &cluster_config();
private:

View File

@ -740,15 +740,14 @@ def set_ps_context(**kwargs):
Some other environment variables should also be set for parameter server training mode.
These environment variables are listed below:
MS_SERVER_NUM # Server number
MS_WORKER_NUM # Worker number
MS_SCHED_HOST # Scheduler IP address
MS_SCHED_PORT # Scheduler port
MS_ROLE # The role of this process:
MS_SCHED #represents the scheduler,
MS_WORKER #represents the worker,
MS_PSERVER #represents the Server
- MS_SERVER_NUM: Server number
- MS_WORKER_NUM: Worker number
- MS_SCHED_HOST: Scheduler IP address
- MS_SCHED_PORT: Scheduler port
- MS_ROLE: The role of this process:
- MS_SCHED: represents the scheduler,
- MS_WORKER: represents the worker,
- MS_PSERVER: represents the Server
Args:
enable_ps (bool): Whether to enable parameter server training mode.
@ -771,6 +770,8 @@ def get_ps_context(attr_key):
Args:
attr_key (str): The key of the attribute.
- "enable_ps": Whether to enable parameter server training mode.
Returns:
Returns attribute value according to the key.
@ -818,7 +819,6 @@ def set_fl_context(**kwargs):
client_epoch_num (int): Client training epoch number. Default: 25.
client_batch_size (int): Client training data batch size. Default: 32.
client_learning_rate (float): Client training learning rate. Default: 0.001.
secure_aggregation (bool): Whether to use secure aggregation algorithm. Default: False.
Raises:
ValueError: If input key is not the attribute in federated learning mode context.
@ -835,11 +835,15 @@ def get_fl_context(attr_key):
Args:
attr_key (str): The key of the attribute.
Please refer to `set_fl_context`'s parameters to decide which key should be passed.
Returns:
Returns attribute value according to the key.
Raises:
ValueError: If input key is not the attribute in federated learning mode context.
Examples:
>>> context.get_fl_context("server_mode")
"""
return _get_ps_context(attr_key)

View File

@ -52,7 +52,6 @@ _set_ps_context_func_map = {
"client_epoch_num": ps_context().set_client_epoch_num,
"client_batch_size": ps_context().set_client_batch_size,
"client_learning_rate": ps_context().set_client_learning_rate,
"secure_aggregation": ps_context().set_secure_aggregation,
"enable_ps_ssl": ps_context().set_enable_ssl
}

View File

@ -13,7 +13,6 @@
# limitations under the License.
# ============================================================================
import ast
import argparse
import subprocess
@ -34,7 +33,6 @@ parser.add_argument("--fl_iteration_num", type=int, default=25)
parser.add_argument("--client_epoch_num", type=int, default=20)
parser.add_argument("--client_batch_size", type=int, default=32)
parser.add_argument("--client_learning_rate", type=float, default=0.1)
parser.add_argument("--secure_aggregation", type=ast.literal_eval, default=False)
parser.add_argument("--local_server_num", type=int, default=-1)
if __name__ == "__main__":
@ -55,7 +53,6 @@ if __name__ == "__main__":
client_epoch_num = args.client_epoch_num
client_batch_size = args.client_batch_size
client_learning_rate = args.client_learning_rate
secure_aggregation = args.secure_aggregation
local_server_num = args.local_server_num
if local_server_num == -1:
@ -86,7 +83,6 @@ if __name__ == "__main__":
cmd_server += " --client_epoch_num=" + str(client_epoch_num)
cmd_server += " --client_batch_size=" + str(client_batch_size)
cmd_server += " --client_learning_rate=" + str(client_learning_rate)
cmd_server += " --secure_aggregation=" + str(secure_aggregation)
cmd_server += " > server.log 2>&1 &"
import time

View File

@ -13,7 +13,6 @@
# limitations under the License.
# ============================================================================
import ast
import argparse
import numpy as np
@ -42,7 +41,6 @@ parser.add_argument("--fl_iteration_num", type=int, default=25)
parser.add_argument("--client_epoch_num", type=int, default=20)
parser.add_argument("--client_batch_size", type=int, default=32)
parser.add_argument("--client_learning_rate", type=float, default=0.1)
parser.add_argument("--secure_aggregation", type=ast.literal_eval, default=False)
args, _ = parser.parse_known_args()
device_target = args.device_target
@ -62,7 +60,6 @@ fl_iteration_num = args.fl_iteration_num
client_epoch_num = args.client_epoch_num
client_batch_size = args.client_batch_size
client_learning_rate = args.client_learning_rate
secure_aggregation = args.secure_aggregation
ctx = {
"enable_fl": True,
@ -82,7 +79,6 @@ ctx = {
"client_epoch_num": client_epoch_num,
"client_batch_size": client_batch_size,
"client_learning_rate": client_learning_rate,
"secure_aggregation": secure_aggregation
}
context.set_context(mode=context.GRAPH_MODE, device_target=device_target, save_graphs=False)