add config file path

This commit is contained in:
w00517672 2021-09-23 14:42:07 +08:00
parent 2010d79336
commit c61b81ba05
3 changed files with 10 additions and 10 deletions

View File

@ -20,10 +20,9 @@ import ast
from time import time
import numpy as np
from mindspore import context, Tensor
from mindspore.train.serialization import save_checkpoint, load_checkpoint
from mindspore.train.serialization import save_checkpoint
from src.adam import AdamWeightDecayOp as AdamWeightDecay
from src.config import train_cfg, server_net_cfg
from src.utils import restore_params
from src.model import AlbertModelCLS
from src.cell_wrapper import NetworkWithCLSLoss, NetworkTrainCell
@ -67,6 +66,7 @@ def parse_args():
parser.add_argument("--share_secrets_ratio", type=float, default=1.0)
parser.add_argument("--cipher_time_window", type=int, default=300000)
parser.add_argument("--reconstruct_secrets_threshold", type=int, default=3)
parser.add_argument("--config_file_path", type=str, default="")
parser.add_argument("--client_password", type=str, default="")
parser.add_argument("--server_password", type=str, default="")
parser.add_argument("--enable_ssl", type=ast.literal_eval, default=False)
@ -77,7 +77,6 @@ def server_train(args):
start = time()
os.environ['CUDA_VISIBLE_DEVICES'] = args.device_id
model_path = args.model_path
output_dir = args.output_dir
device_target = args.device_target
@ -104,6 +103,7 @@ def server_train(args):
share_secrets_ratio = args.share_secrets_ratio
cipher_time_window = args.cipher_time_window
reconstruct_secrets_threshold = args.reconstruct_secrets_threshold
config_file_path = args.config_file_path
client_password = args.client_password
server_password = args.server_password
enable_ssl = args.enable_ssl
@ -136,6 +136,7 @@ def server_train(args):
"share_secrets_ratio": share_secrets_ratio,
"cipher_time_window": cipher_time_window,
"reconstruct_secrets_threshold": reconstruct_secrets_threshold,
"config_file_path": config_file_path,
"client_password": client_password,
"server_password": server_password,
"enable_ssl": enable_ssl
@ -160,11 +161,6 @@ def server_train(args):
sys.stdout.flush()
start = time()
# train prepare
param_dict = load_checkpoint(model_path)
if 'learning_rate' in param_dict:
del param_dict['learning_rate']
# server optimizer
server_params = [_ for _ in network_with_cls_loss.trainable_params()]
server_decay_params = list(
@ -183,8 +179,6 @@ def server_train(args):
eps=train_cfg.optimizer_cfg.AdamWeightDecay.eps)
server_network_train_cell = NetworkTrainCell(network_with_cls_loss, optimizer=server_optimizer)
restore_params(server_network_train_cell, param_dict)
print('Optimizer construction is done! Time cost: {}'.format(time() - start))
sys.stdout.flush()
start = time()

View File

@ -29,6 +29,7 @@ parser.add_argument("--scheduler_manage_port", type=int, default=11202)
parser.add_argument("--client_password", type=str, default="")
parser.add_argument("--server_password", type=str, default="")
parser.add_argument("--enable_ssl", type=ast.literal_eval, default=False)
parser.add_argument("--config_file_path", type=str, default="")
args, _ = parser.parse_known_args()
device_target = args.device_target
@ -41,6 +42,7 @@ scheduler_manage_port = args.scheduler_manage_port
client_password = args.client_password
server_password = args.server_password
enable_ssl = args.enable_ssl
config_file_path = args.config_file_path
os.environ['MS_NODE_ID'] = "20"
cmd_sched = "execute_path=$(pwd) && self_path=$(dirname \"${script_self}\") && rm -rf ${execute_path}/scheduler/ &&"
@ -58,6 +60,7 @@ cmd_sched += " --client_password=" + str(client_password)
cmd_sched += " --server_password=" + str(server_password)
cmd_sched += " --enable_ssl=" + str(enable_ssl)
cmd_sched += " --scheduler_manage_port=" + str(scheduler_manage_port)
cmd_sched += " --config_file_path=" + config_file_path
cmd_sched += " > scheduler.log 2>&1 &"
subprocess.call(['bash', '-c', cmd_sched])

View File

@ -49,6 +49,7 @@ parser.add_argument("--reconstruct_secrets_threshold", type=int, default=3)
parser.add_argument("--client_password", type=str, default="")
parser.add_argument("--server_password", type=str, default="")
parser.add_argument("--enable_ssl", type=ast.literal_eval, default=False)
parser.add_argument("--config_file_path", type=str, default="")
args, _ = parser.parse_known_args()
device_target = args.device_target
@ -78,6 +79,7 @@ reconstruct_secrets_threshold = args.reconstruct_secrets_threshold
client_password = args.client_password
server_password = args.server_password
enable_ssl = args.enable_ssl
config_file_path = args.config_file_path
if local_server_num == -1:
local_server_num = server_num
@ -118,6 +120,7 @@ for i in range(local_server_num):
cmd_server += " --server_password=" + str(server_password)
cmd_server += " --enable_ssl=" + str(enable_ssl)
cmd_server += " --reconstruct_secrets_threshold=" + str(reconstruct_secrets_threshold)
cmd_server += " --config_file_path=" + config_file_path
cmd_server += " > server.log 2>&1 &"
import time