forked from mindspore-Ecosystem/mindspore
!31703 add parameters for fl compression
Merge pull request !31703 from wtcheng/master
This commit is contained in:
commit
8de324cfa5
|
@ -21,7 +21,7 @@ from time import time
|
|||
import numpy as np
|
||||
from mindspore import context, Tensor
|
||||
from mindspore.train.serialization import save_checkpoint
|
||||
from src.adam import AdamWeightDecayOp as AdamWeightDecay
|
||||
from mindspore.nn import Adam as AdamWeightDecay
|
||||
from src.config import train_cfg, server_net_cfg
|
||||
from src.model import AlbertModelCLS
|
||||
from src.cell_wrapper import NetworkWithCLSLoss, NetworkTrainCell
|
||||
|
@ -82,6 +82,13 @@ def parse_args():
|
|||
parser.add_argument("--sign_thr_ratio", type=float, default=0.6)
|
||||
parser.add_argument("--sign_global_lr", type=float, default=0.1)
|
||||
parser.add_argument("--sign_dim_out", type=int, default=0)
|
||||
parser.add_argument("--global_iteration_time_window", type=int, default=3600000)
|
||||
# parameters for "compression"
|
||||
parser.add_argument("--upload_compress_type", type=str, default="NO_COMPRESS",
|
||||
choices=["NO_COMPRESS", "DIFF_SPARSE_QUANT"])
|
||||
parser.add_argument("--upload_sparse_rate", type=float, default=0.5)
|
||||
parser.add_argument("--download_compress_type", type=str, default="NO_COMPRESS",
|
||||
choices=["NO_COMPRESS", "QUANT"])
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
|
@ -129,6 +136,10 @@ def server_train(args):
|
|||
sign_thr_ratio = args.sign_thr_ratio
|
||||
sign_global_lr = args.sign_global_lr
|
||||
sign_dim_out = args.sign_dim_out
|
||||
global_iteration_time_window = args.global_iteration_time_window
|
||||
upload_compress_type = args.upload_compress_type
|
||||
upload_sparse_rate = args.upload_sparse_rate
|
||||
download_compress_type = args.download_compress_type
|
||||
|
||||
# Replace some parameters with federated learning parameters.
|
||||
train_cfg.max_global_epoch = fl_iteration_num
|
||||
|
@ -171,7 +182,11 @@ def server_train(args):
|
|||
"sign_eps": sign_eps,
|
||||
"sign_thr_ratio": sign_thr_ratio,
|
||||
"sign_global_lr": sign_global_lr,
|
||||
"sign_dim_out": sign_dim_out
|
||||
"sign_dim_out": sign_dim_out,
|
||||
"global_iteration_time_window": global_iteration_time_window,
|
||||
"upload_compress_type": upload_compress_type,
|
||||
"upload_sparse_rate": upload_sparse_rate,
|
||||
"download_compress_type": download_compress_type,
|
||||
}
|
||||
|
||||
if not os.path.exists(output_dir):
|
||||
|
|
|
@ -64,6 +64,13 @@ parser.add_argument("--sign_eps", type=float, default=100)
|
|||
parser.add_argument("--sign_thr_ratio", type=float, default=0.6)
|
||||
parser.add_argument("--sign_global_lr", type=float, default=0.1)
|
||||
parser.add_argument("--sign_dim_out", type=int, default=0)
|
||||
parser.add_argument("--global_iteration_time_window", type=int, default=3600000)
|
||||
# parameters for "compression"
|
||||
parser.add_argument("--upload_compress_type", type=str, default="NO_COMPRESS",
|
||||
choices=["NO_COMPRESS", "DIFF_SPARSE_QUANT"])
|
||||
parser.add_argument("--upload_sparse_rate", type=float, default=0.5)
|
||||
parser.add_argument("--download_compress_type", type=str, default="NO_COMPRESS",
|
||||
choices=["NO_COMPRESS", "QUANT"])
|
||||
|
||||
args, _ = parser.parse_known_args()
|
||||
device_target = args.device_target
|
||||
|
@ -104,6 +111,10 @@ sign_eps = args.sign_eps
|
|||
sign_thr_ratio = args.sign_thr_ratio
|
||||
sign_global_lr = args.sign_global_lr
|
||||
sign_dim_out = args.sign_dim_out
|
||||
global_iteration_time_window = args.global_iteration_time_window
|
||||
upload_compress_type = args.upload_compress_type
|
||||
upload_sparse_rate = args.upload_sparse_rate
|
||||
download_compress_type = args.download_compress_type
|
||||
|
||||
if local_server_num == -1:
|
||||
local_server_num = server_num
|
||||
|
@ -155,6 +166,10 @@ for i in range(local_server_num):
|
|||
cmd_server += " --sign_thr_ratio=" + str(sign_thr_ratio)
|
||||
cmd_server += " --sign_global_lr=" + str(sign_global_lr)
|
||||
cmd_server += " --sign_dim_out=" + str(sign_dim_out)
|
||||
cmd_server += " --global_iteration_time_window=" + str(global_iteration_time_window)
|
||||
cmd_server += " --upload_compress_type=" + str(upload_compress_type)
|
||||
cmd_server += " --upload_sparse_rate=" + str(upload_sparse_rate)
|
||||
cmd_server += " --download_compress_type=" + str(download_compress_type)
|
||||
cmd_server += " > server.log 2>&1 &"
|
||||
|
||||
import time
|
||||
|
|
|
@ -55,6 +55,13 @@ parser.add_argument("--sign_eps", type=float, default=100)
|
|||
parser.add_argument("--sign_thr_ratio", type=float, default=0.6)
|
||||
parser.add_argument("--sign_global_lr", type=float, default=0.1)
|
||||
parser.add_argument("--sign_dim_out", type=int, default=0)
|
||||
parser.add_argument("--global_iteration_time_window", type=int, default=3600000)
|
||||
# parameters for "compression"
|
||||
parser.add_argument("--upload_compress_type", type=str, default="NO_COMPRESS",
|
||||
choices=["NO_COMPRESS", "DIFF_SPARSE_QUANT"])
|
||||
parser.add_argument("--upload_sparse_rate", type=float, default=0.5)
|
||||
parser.add_argument("--download_compress_type", type=str, default="NO_COMPRESS",
|
||||
choices=["NO_COMPRESS", "QUANT"])
|
||||
|
||||
if __name__ == "__main__":
|
||||
args, _ = parser.parse_known_args()
|
||||
|
@ -91,6 +98,10 @@ if __name__ == "__main__":
|
|||
sign_thr_ratio = args.sign_thr_ratio
|
||||
sign_global_lr = args.sign_global_lr
|
||||
sign_dim_out = args.sign_dim_out
|
||||
global_iteration_time_window = args.global_iteration_time_window
|
||||
upload_compress_type = args.upload_compress_type
|
||||
upload_sparse_rate = args.upload_sparse_rate
|
||||
download_compress_type = args.download_compress_type
|
||||
|
||||
if local_server_num == -1:
|
||||
local_server_num = server_num
|
||||
|
@ -137,6 +148,10 @@ if __name__ == "__main__":
|
|||
cmd_server += " --sign_thr_ratio=" + str(sign_thr_ratio)
|
||||
cmd_server += " --sign_global_lr=" + str(sign_global_lr)
|
||||
cmd_server += " --sign_dim_out=" + str(sign_dim_out)
|
||||
cmd_server += " --global_iteration_time_window=" + str(global_iteration_time_window)
|
||||
cmd_server += " --upload_compress_type=" + str(upload_compress_type)
|
||||
cmd_server += " --upload_sparse_rate=" + str(upload_sparse_rate)
|
||||
cmd_server += " --download_compress_type=" + str(download_compress_type)
|
||||
cmd_server += " > server.log 2>&1 &"
|
||||
|
||||
import time
|
||||
|
|
|
@ -62,6 +62,13 @@ parser.add_argument("--sign_eps", type=float, default=100)
|
|||
parser.add_argument("--sign_thr_ratio", type=float, default=0.6)
|
||||
parser.add_argument("--sign_global_lr", type=float, default=0.1)
|
||||
parser.add_argument("--sign_dim_out", type=int, default=0)
|
||||
parser.add_argument("--global_iteration_time_window", type=int, default=3600000)
|
||||
# parameters for "compression"
|
||||
parser.add_argument("--upload_compress_type", type=str, default="NO_COMPRESS",
|
||||
choices=["NO_COMPRESS", "DIFF_SPARSE_QUANT"])
|
||||
parser.add_argument("--upload_sparse_rate", type=float, default=0.5)
|
||||
parser.add_argument("--download_compress_type", type=str, default="NO_COMPRESS",
|
||||
choices=["NO_COMPRESS", "QUANT"])
|
||||
|
||||
args, _ = parser.parse_known_args()
|
||||
device_target = args.device_target
|
||||
|
@ -98,6 +105,10 @@ sign_eps = args.sign_eps
|
|||
sign_thr_ratio = args.sign_thr_ratio
|
||||
sign_global_lr = args.sign_global_lr
|
||||
sign_dim_out = args.sign_dim_out
|
||||
global_iteration_time_window = args.global_iteration_time_window
|
||||
upload_compress_type = args.upload_compress_type
|
||||
upload_sparse_rate = args.upload_sparse_rate
|
||||
download_compress_type = args.download_compress_type
|
||||
|
||||
ctx = {
|
||||
"enable_fl": True,
|
||||
|
@ -133,7 +144,11 @@ ctx = {
|
|||
"sign_eps": sign_eps,
|
||||
"sign_thr_ratio": sign_thr_ratio,
|
||||
"sign_global_lr": sign_global_lr,
|
||||
"sign_dim_out": sign_dim_out
|
||||
"sign_dim_out": sign_dim_out,
|
||||
"global_iteration_time_window": global_iteration_time_window,
|
||||
"upload_compress_type": upload_compress_type,
|
||||
"upload_sparse_rate": upload_sparse_rate,
|
||||
"download_compress_type": download_compress_type,
|
||||
}
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
|
||||
|
|
|
@ -63,6 +63,13 @@ parser.add_argument("--sign_eps", type=float, default=100)
|
|||
parser.add_argument("--sign_thr_ratio", type=float, default=0.6)
|
||||
parser.add_argument("--sign_global_lr", type=float, default=0.1)
|
||||
parser.add_argument("--sign_dim_out", type=int, default=0)
|
||||
parser.add_argument("--global_iteration_time_window", type=int, default=3600000)
|
||||
# parameters for "compression"
|
||||
parser.add_argument("--upload_compress_type", type=str, default="NO_COMPRESS",
|
||||
choices=["NO_COMPRESS", "DIFF_SPARSE_QUANT"])
|
||||
parser.add_argument("--upload_sparse_rate", type=float, default=0.5)
|
||||
parser.add_argument("--download_compress_type", type=str, default="NO_COMPRESS",
|
||||
choices=["NO_COMPRESS", "QUANT"])
|
||||
|
||||
args, _ = parser.parse_known_args()
|
||||
device_target = args.device_target
|
||||
|
@ -103,6 +110,10 @@ sign_eps = args.sign_eps
|
|||
sign_thr_ratio = args.sign_thr_ratio
|
||||
sign_global_lr = args.sign_global_lr
|
||||
sign_dim_out = args.sign_dim_out
|
||||
global_iteration_time_window = args.global_iteration_time_window
|
||||
upload_compress_type = args.upload_compress_type
|
||||
upload_sparse_rate = args.upload_sparse_rate
|
||||
download_compress_type = args.download_compress_type
|
||||
|
||||
if local_server_num == -1:
|
||||
local_server_num = server_num
|
||||
|
@ -154,6 +165,10 @@ for i in range(local_server_num):
|
|||
cmd_server += " --sign_thr_ratio=" + str(sign_thr_ratio)
|
||||
cmd_server += " --sign_global_lr=" + str(sign_global_lr)
|
||||
cmd_server += " --sign_dim_out=" + str(sign_dim_out)
|
||||
cmd_server += " --global_iteration_time_window=" + str(global_iteration_time_window)
|
||||
cmd_server += " --upload_compress_type=" + str(upload_compress_type)
|
||||
cmd_server += " --upload_sparse_rate=" + str(upload_sparse_rate)
|
||||
cmd_server += " --download_compress_type=" + str(download_compress_type)
|
||||
cmd_server += " > server.log 2>&1 &"
|
||||
|
||||
import time
|
||||
|
|
|
@ -71,6 +71,13 @@ parser.add_argument("--sign_eps", type=float, default=100)
|
|||
parser.add_argument("--sign_thr_ratio", type=float, default=0.6)
|
||||
parser.add_argument("--sign_global_lr", type=float, default=0.1)
|
||||
parser.add_argument("--sign_dim_out", type=int, default=0)
|
||||
parser.add_argument("--global_iteration_time_window", type=int, default=3600000)
|
||||
# parameters for "compression"
|
||||
parser.add_argument("--upload_compress_type", type=str, default="NO_COMPRESS",
|
||||
choices=["NO_COMPRESS", "DIFF_SPARSE_QUANT"])
|
||||
parser.add_argument("--upload_sparse_rate", type=float, default=0.5)
|
||||
parser.add_argument("--download_compress_type", type=str, default="NO_COMPRESS",
|
||||
choices=["NO_COMPRESS", "QUANT"])
|
||||
|
||||
args, _ = parser.parse_known_args()
|
||||
device_target = args.device_target
|
||||
|
@ -113,6 +120,10 @@ sign_eps = args.sign_eps
|
|||
sign_thr_ratio = args.sign_thr_ratio
|
||||
sign_global_lr = args.sign_global_lr
|
||||
sign_dim_out = args.sign_dim_out
|
||||
global_iteration_time_window = args.global_iteration_time_window
|
||||
upload_compress_type = args.upload_compress_type
|
||||
upload_sparse_rate = args.upload_sparse_rate
|
||||
download_compress_type = args.download_compress_type
|
||||
|
||||
ctx = {
|
||||
"enable_fl": True,
|
||||
|
@ -154,7 +165,11 @@ ctx = {
|
|||
"sign_eps": sign_eps,
|
||||
"sign_thr_ratio": sign_thr_ratio,
|
||||
"sign_global_lr": sign_global_lr,
|
||||
"sign_dim_out": sign_dim_out
|
||||
"sign_dim_out": sign_dim_out,
|
||||
"global_iteration_time_window": global_iteration_time_window,
|
||||
"upload_compress_type": upload_compress_type,
|
||||
"upload_sparse_rate": upload_sparse_rate,
|
||||
"download_compress_type": download_compress_type,
|
||||
}
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
|
||||
|
|
|
@ -61,6 +61,13 @@ parser.add_argument("--sign_eps", type=float, default=100)
|
|||
parser.add_argument("--sign_thr_ratio", type=float, default=0.6)
|
||||
parser.add_argument("--sign_global_lr", type=float, default=0.1)
|
||||
parser.add_argument("--sign_dim_out", type=int, default=0)
|
||||
parser.add_argument("--global_iteration_time_window", type=int, default=3600000)
|
||||
# parameters for "compression"
|
||||
parser.add_argument("--upload_compress_type", type=str, default="NO_COMPRESS",
|
||||
choices=["NO_COMPRESS", "DIFF_SPARSE_QUANT"])
|
||||
parser.add_argument("--upload_sparse_rate", type=float, default=0.5)
|
||||
parser.add_argument("--download_compress_type", type=str, default="NO_COMPRESS",
|
||||
choices=["NO_COMPRESS", "QUANT"])
|
||||
|
||||
if __name__ == "__main__":
|
||||
args, _ = parser.parse_known_args()
|
||||
|
@ -102,6 +109,10 @@ if __name__ == "__main__":
|
|||
sign_thr_ratio = args.sign_thr_ratio
|
||||
sign_global_lr = args.sign_global_lr
|
||||
sign_dim_out = args.sign_dim_out
|
||||
global_iteration_time_window = args.global_iteration_time_window
|
||||
upload_compress_type = args.upload_compress_type
|
||||
upload_sparse_rate = args.upload_sparse_rate
|
||||
download_compress_type = args.download_compress_type
|
||||
|
||||
if local_server_num == -1:
|
||||
local_server_num = server_num
|
||||
|
@ -153,6 +164,10 @@ if __name__ == "__main__":
|
|||
cmd_server += " --sign_thr_ratio=" + str(sign_thr_ratio)
|
||||
cmd_server += " --sign_global_lr=" + str(sign_global_lr)
|
||||
cmd_server += " --sign_dim_out=" + str(sign_dim_out)
|
||||
cmd_server += " --global_iteration_time_window=" + str(global_iteration_time_window)
|
||||
cmd_server += " --upload_compress_type=" + str(upload_compress_type)
|
||||
cmd_server += " --upload_sparse_rate=" + str(upload_sparse_rate)
|
||||
cmd_server += " --download_compress_type=" + str(download_compress_type)
|
||||
cmd_server += " > server.log 2>&1 &"
|
||||
|
||||
import time
|
||||
|
|
|
@ -70,6 +70,13 @@ parser.add_argument("--sign_eps", type=float, default=100)
|
|||
parser.add_argument("--sign_thr_ratio", type=float, default=0.6)
|
||||
parser.add_argument("--sign_global_lr", type=float, default=0.1)
|
||||
parser.add_argument("--sign_dim_out", type=int, default=0)
|
||||
parser.add_argument("--global_iteration_time_window", type=int, default=3600000)
|
||||
# parameters for "compression"
|
||||
parser.add_argument("--upload_compress_type", type=str, default="NO_COMPRESS",
|
||||
choices=["NO_COMPRESS", "DIFF_SPARSE_QUANT"])
|
||||
parser.add_argument("--upload_sparse_rate", type=float, default=0.5)
|
||||
parser.add_argument("--download_compress_type", type=str, default="NO_COMPRESS",
|
||||
choices=["NO_COMPRESS", "QUANT"])
|
||||
|
||||
args, _ = parser.parse_known_args()
|
||||
device_target = args.device_target
|
||||
|
@ -111,6 +118,10 @@ sign_eps = args.sign_eps
|
|||
sign_thr_ratio = args.sign_thr_ratio
|
||||
sign_global_lr = args.sign_global_lr
|
||||
sign_dim_out = args.sign_dim_out
|
||||
global_iteration_time_window = args.global_iteration_time_window
|
||||
upload_compress_type = args.upload_compress_type
|
||||
upload_sparse_rate = args.upload_sparse_rate
|
||||
download_compress_type = args.download_compress_type
|
||||
|
||||
ctx = {
|
||||
"enable_fl": True,
|
||||
|
@ -151,7 +162,11 @@ ctx = {
|
|||
"sign_eps": sign_eps,
|
||||
"sign_thr_ratio": sign_thr_ratio,
|
||||
"sign_global_lr": sign_global_lr,
|
||||
"sign_dim_out": sign_dim_out
|
||||
"sign_dim_out": sign_dim_out,
|
||||
"global_iteration_time_window": global_iteration_time_window,
|
||||
"upload_compress_type": upload_compress_type,
|
||||
"upload_sparse_rate": upload_sparse_rate,
|
||||
"download_compress_type": download_compress_type,
|
||||
}
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
|
||||
|
|
Loading…
Reference in New Issue