!31703 add parameters for fl compression

Merge pull request !31703 from wtcheng/master
This commit is contained in:
i-robot 2022-03-25 06:36:27 +00:00 committed by Gitee
commit 8de324cfa5
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
8 changed files with 125 additions and 5 deletions

View File

@ -21,7 +21,7 @@ from time import time
import numpy as np
from mindspore import context, Tensor
from mindspore.train.serialization import save_checkpoint
from src.adam import AdamWeightDecayOp as AdamWeightDecay
from mindspore.nn import Adam as AdamWeightDecay
from src.config import train_cfg, server_net_cfg
from src.model import AlbertModelCLS
from src.cell_wrapper import NetworkWithCLSLoss, NetworkTrainCell
@ -82,6 +82,13 @@ def parse_args():
parser.add_argument("--sign_thr_ratio", type=float, default=0.6)
parser.add_argument("--sign_global_lr", type=float, default=0.1)
parser.add_argument("--sign_dim_out", type=int, default=0)
parser.add_argument("--global_iteration_time_window", type=int, default=3600000)
# parameters for "compression"
parser.add_argument("--upload_compress_type", type=str, default="NO_COMPRESS",
choices=["NO_COMPRESS", "DIFF_SPARSE_QUANT"])
parser.add_argument("--upload_sparse_rate", type=float, default=0.5)
parser.add_argument("--download_compress_type", type=str, default="NO_COMPRESS",
choices=["NO_COMPRESS", "QUANT"])
return parser.parse_args()
@ -129,6 +136,10 @@ def server_train(args):
sign_thr_ratio = args.sign_thr_ratio
sign_global_lr = args.sign_global_lr
sign_dim_out = args.sign_dim_out
global_iteration_time_window = args.global_iteration_time_window
upload_compress_type = args.upload_compress_type
upload_sparse_rate = args.upload_sparse_rate
download_compress_type = args.download_compress_type
# Replace some parameters with federated learning parameters.
train_cfg.max_global_epoch = fl_iteration_num
@ -171,7 +182,11 @@ def server_train(args):
"sign_eps": sign_eps,
"sign_thr_ratio": sign_thr_ratio,
"sign_global_lr": sign_global_lr,
"sign_dim_out": sign_dim_out
"sign_dim_out": sign_dim_out,
"global_iteration_time_window": global_iteration_time_window,
"upload_compress_type": upload_compress_type,
"upload_sparse_rate": upload_sparse_rate,
"download_compress_type": download_compress_type,
}
if not os.path.exists(output_dir):

View File

@ -64,6 +64,13 @@ parser.add_argument("--sign_eps", type=float, default=100)
parser.add_argument("--sign_thr_ratio", type=float, default=0.6)
parser.add_argument("--sign_global_lr", type=float, default=0.1)
parser.add_argument("--sign_dim_out", type=int, default=0)
parser.add_argument("--global_iteration_time_window", type=int, default=3600000)
# parameters for "compression"
parser.add_argument("--upload_compress_type", type=str, default="NO_COMPRESS",
choices=["NO_COMPRESS", "DIFF_SPARSE_QUANT"])
parser.add_argument("--upload_sparse_rate", type=float, default=0.5)
parser.add_argument("--download_compress_type", type=str, default="NO_COMPRESS",
choices=["NO_COMPRESS", "QUANT"])
args, _ = parser.parse_known_args()
device_target = args.device_target
@ -104,6 +111,10 @@ sign_eps = args.sign_eps
sign_thr_ratio = args.sign_thr_ratio
sign_global_lr = args.sign_global_lr
sign_dim_out = args.sign_dim_out
global_iteration_time_window = args.global_iteration_time_window
upload_compress_type = args.upload_compress_type
upload_sparse_rate = args.upload_sparse_rate
download_compress_type = args.download_compress_type
if local_server_num == -1:
local_server_num = server_num
@ -155,6 +166,10 @@ for i in range(local_server_num):
cmd_server += " --sign_thr_ratio=" + str(sign_thr_ratio)
cmd_server += " --sign_global_lr=" + str(sign_global_lr)
cmd_server += " --sign_dim_out=" + str(sign_dim_out)
cmd_server += " --global_iteration_time_window=" + str(global_iteration_time_window)
cmd_server += " --upload_compress_type=" + str(upload_compress_type)
cmd_server += " --upload_sparse_rate=" + str(upload_sparse_rate)
cmd_server += " --download_compress_type=" + str(download_compress_type)
cmd_server += " > server.log 2>&1 &"
import time

View File

@ -55,6 +55,13 @@ parser.add_argument("--sign_eps", type=float, default=100)
parser.add_argument("--sign_thr_ratio", type=float, default=0.6)
parser.add_argument("--sign_global_lr", type=float, default=0.1)
parser.add_argument("--sign_dim_out", type=int, default=0)
parser.add_argument("--global_iteration_time_window", type=int, default=3600000)
# parameters for "compression"
parser.add_argument("--upload_compress_type", type=str, default="NO_COMPRESS",
choices=["NO_COMPRESS", "DIFF_SPARSE_QUANT"])
parser.add_argument("--upload_sparse_rate", type=float, default=0.5)
parser.add_argument("--download_compress_type", type=str, default="NO_COMPRESS",
choices=["NO_COMPRESS", "QUANT"])
if __name__ == "__main__":
args, _ = parser.parse_known_args()
@ -91,6 +98,10 @@ if __name__ == "__main__":
sign_thr_ratio = args.sign_thr_ratio
sign_global_lr = args.sign_global_lr
sign_dim_out = args.sign_dim_out
global_iteration_time_window = args.global_iteration_time_window
upload_compress_type = args.upload_compress_type
upload_sparse_rate = args.upload_sparse_rate
download_compress_type = args.download_compress_type
if local_server_num == -1:
local_server_num = server_num
@ -137,6 +148,10 @@ if __name__ == "__main__":
cmd_server += " --sign_thr_ratio=" + str(sign_thr_ratio)
cmd_server += " --sign_global_lr=" + str(sign_global_lr)
cmd_server += " --sign_dim_out=" + str(sign_dim_out)
cmd_server += " --global_iteration_time_window=" + str(global_iteration_time_window)
cmd_server += " --upload_compress_type=" + str(upload_compress_type)
cmd_server += " --upload_sparse_rate=" + str(upload_sparse_rate)
cmd_server += " --download_compress_type=" + str(download_compress_type)
cmd_server += " > server.log 2>&1 &"
import time

View File

@ -62,6 +62,13 @@ parser.add_argument("--sign_eps", type=float, default=100)
parser.add_argument("--sign_thr_ratio", type=float, default=0.6)
parser.add_argument("--sign_global_lr", type=float, default=0.1)
parser.add_argument("--sign_dim_out", type=int, default=0)
parser.add_argument("--global_iteration_time_window", type=int, default=3600000)
# parameters for "compression"
parser.add_argument("--upload_compress_type", type=str, default="NO_COMPRESS",
choices=["NO_COMPRESS", "DIFF_SPARSE_QUANT"])
parser.add_argument("--upload_sparse_rate", type=float, default=0.5)
parser.add_argument("--download_compress_type", type=str, default="NO_COMPRESS",
choices=["NO_COMPRESS", "QUANT"])
args, _ = parser.parse_known_args()
device_target = args.device_target
@ -98,6 +105,10 @@ sign_eps = args.sign_eps
sign_thr_ratio = args.sign_thr_ratio
sign_global_lr = args.sign_global_lr
sign_dim_out = args.sign_dim_out
global_iteration_time_window = args.global_iteration_time_window
upload_compress_type = args.upload_compress_type
upload_sparse_rate = args.upload_sparse_rate
download_compress_type = args.download_compress_type
ctx = {
"enable_fl": True,
@ -133,7 +144,11 @@ ctx = {
"sign_eps": sign_eps,
"sign_thr_ratio": sign_thr_ratio,
"sign_global_lr": sign_global_lr,
"sign_dim_out": sign_dim_out
"sign_dim_out": sign_dim_out,
"global_iteration_time_window": global_iteration_time_window,
"upload_compress_type": upload_compress_type,
"upload_sparse_rate": upload_sparse_rate,
"download_compress_type": download_compress_type,
}
context.set_context(mode=context.GRAPH_MODE, device_target=device_target)

View File

@ -63,6 +63,13 @@ parser.add_argument("--sign_eps", type=float, default=100)
parser.add_argument("--sign_thr_ratio", type=float, default=0.6)
parser.add_argument("--sign_global_lr", type=float, default=0.1)
parser.add_argument("--sign_dim_out", type=int, default=0)
parser.add_argument("--global_iteration_time_window", type=int, default=3600000)
# parameters for "compression"
parser.add_argument("--upload_compress_type", type=str, default="NO_COMPRESS",
choices=["NO_COMPRESS", "DIFF_SPARSE_QUANT"])
parser.add_argument("--upload_sparse_rate", type=float, default=0.5)
parser.add_argument("--download_compress_type", type=str, default="NO_COMPRESS",
choices=["NO_COMPRESS", "QUANT"])
args, _ = parser.parse_known_args()
device_target = args.device_target
@ -103,6 +110,10 @@ sign_eps = args.sign_eps
sign_thr_ratio = args.sign_thr_ratio
sign_global_lr = args.sign_global_lr
sign_dim_out = args.sign_dim_out
global_iteration_time_window = args.global_iteration_time_window
upload_compress_type = args.upload_compress_type
upload_sparse_rate = args.upload_sparse_rate
download_compress_type = args.download_compress_type
if local_server_num == -1:
local_server_num = server_num
@ -154,6 +165,10 @@ for i in range(local_server_num):
cmd_server += " --sign_thr_ratio=" + str(sign_thr_ratio)
cmd_server += " --sign_global_lr=" + str(sign_global_lr)
cmd_server += " --sign_dim_out=" + str(sign_dim_out)
cmd_server += " --global_iteration_time_window=" + str(global_iteration_time_window)
cmd_server += " --upload_compress_type=" + str(upload_compress_type)
cmd_server += " --upload_sparse_rate=" + str(upload_sparse_rate)
cmd_server += " --download_compress_type=" + str(download_compress_type)
cmd_server += " > server.log 2>&1 &"
import time

View File

@ -71,6 +71,13 @@ parser.add_argument("--sign_eps", type=float, default=100)
parser.add_argument("--sign_thr_ratio", type=float, default=0.6)
parser.add_argument("--sign_global_lr", type=float, default=0.1)
parser.add_argument("--sign_dim_out", type=int, default=0)
parser.add_argument("--global_iteration_time_window", type=int, default=3600000)
# parameters for "compression"
parser.add_argument("--upload_compress_type", type=str, default="NO_COMPRESS",
choices=["NO_COMPRESS", "DIFF_SPARSE_QUANT"])
parser.add_argument("--upload_sparse_rate", type=float, default=0.5)
parser.add_argument("--download_compress_type", type=str, default="NO_COMPRESS",
choices=["NO_COMPRESS", "QUANT"])
args, _ = parser.parse_known_args()
device_target = args.device_target
@ -113,6 +120,10 @@ sign_eps = args.sign_eps
sign_thr_ratio = args.sign_thr_ratio
sign_global_lr = args.sign_global_lr
sign_dim_out = args.sign_dim_out
global_iteration_time_window = args.global_iteration_time_window
upload_compress_type = args.upload_compress_type
upload_sparse_rate = args.upload_sparse_rate
download_compress_type = args.download_compress_type
ctx = {
"enable_fl": True,
@ -154,7 +165,11 @@ ctx = {
"sign_eps": sign_eps,
"sign_thr_ratio": sign_thr_ratio,
"sign_global_lr": sign_global_lr,
"sign_dim_out": sign_dim_out
"sign_dim_out": sign_dim_out,
"global_iteration_time_window": global_iteration_time_window,
"upload_compress_type": upload_compress_type,
"upload_sparse_rate": upload_sparse_rate,
"download_compress_type": download_compress_type,
}
context.set_context(mode=context.GRAPH_MODE, device_target=device_target)

View File

@ -61,6 +61,13 @@ parser.add_argument("--sign_eps", type=float, default=100)
parser.add_argument("--sign_thr_ratio", type=float, default=0.6)
parser.add_argument("--sign_global_lr", type=float, default=0.1)
parser.add_argument("--sign_dim_out", type=int, default=0)
parser.add_argument("--global_iteration_time_window", type=int, default=3600000)
# parameters for "compression"
parser.add_argument("--upload_compress_type", type=str, default="NO_COMPRESS",
choices=["NO_COMPRESS", "DIFF_SPARSE_QUANT"])
parser.add_argument("--upload_sparse_rate", type=float, default=0.5)
parser.add_argument("--download_compress_type", type=str, default="NO_COMPRESS",
choices=["NO_COMPRESS", "QUANT"])
if __name__ == "__main__":
args, _ = parser.parse_known_args()
@ -102,6 +109,10 @@ if __name__ == "__main__":
sign_thr_ratio = args.sign_thr_ratio
sign_global_lr = args.sign_global_lr
sign_dim_out = args.sign_dim_out
global_iteration_time_window = args.global_iteration_time_window
upload_compress_type = args.upload_compress_type
upload_sparse_rate = args.upload_sparse_rate
download_compress_type = args.download_compress_type
if local_server_num == -1:
local_server_num = server_num
@ -153,6 +164,10 @@ if __name__ == "__main__":
cmd_server += " --sign_thr_ratio=" + str(sign_thr_ratio)
cmd_server += " --sign_global_lr=" + str(sign_global_lr)
cmd_server += " --sign_dim_out=" + str(sign_dim_out)
cmd_server += " --global_iteration_time_window=" + str(global_iteration_time_window)
cmd_server += " --upload_compress_type=" + str(upload_compress_type)
cmd_server += " --upload_sparse_rate=" + str(upload_sparse_rate)
cmd_server += " --download_compress_type=" + str(download_compress_type)
cmd_server += " > server.log 2>&1 &"
import time

View File

@ -70,6 +70,13 @@ parser.add_argument("--sign_eps", type=float, default=100)
parser.add_argument("--sign_thr_ratio", type=float, default=0.6)
parser.add_argument("--sign_global_lr", type=float, default=0.1)
parser.add_argument("--sign_dim_out", type=int, default=0)
parser.add_argument("--global_iteration_time_window", type=int, default=3600000)
# parameters for "compression"
parser.add_argument("--upload_compress_type", type=str, default="NO_COMPRESS",
choices=["NO_COMPRESS", "DIFF_SPARSE_QUANT"])
parser.add_argument("--upload_sparse_rate", type=float, default=0.5)
parser.add_argument("--download_compress_type", type=str, default="NO_COMPRESS",
choices=["NO_COMPRESS", "QUANT"])
args, _ = parser.parse_known_args()
device_target = args.device_target
@ -111,6 +118,10 @@ sign_eps = args.sign_eps
sign_thr_ratio = args.sign_thr_ratio
sign_global_lr = args.sign_global_lr
sign_dim_out = args.sign_dim_out
global_iteration_time_window = args.global_iteration_time_window
upload_compress_type = args.upload_compress_type
upload_sparse_rate = args.upload_sparse_rate
download_compress_type = args.download_compress_type
ctx = {
"enable_fl": True,
@ -151,7 +162,11 @@ ctx = {
"sign_eps": sign_eps,
"sign_thr_ratio": sign_thr_ratio,
"sign_global_lr": sign_global_lr,
"sign_dim_out": sign_dim_out
"sign_dim_out": sign_dim_out,
"global_iteration_time_window": global_iteration_time_window,
"upload_compress_type": upload_compress_type,
"upload_sparse_rate": upload_sparse_rate,
"download_compress_type": download_compress_type,
}
context.set_context(mode=context.GRAPH_MODE, device_target=device_target)