!23986 add config file path

Merge pull request !23986 from wtcheng/master
This commit is contained in:
i-robot 2021-09-24 07:15:59 +00:00 committed by Gitee
commit 3e1ef7823a
3 changed files with 10 additions and 10 deletions

View File

@ -20,10 +20,9 @@ import ast
from time import time from time import time
import numpy as np import numpy as np
from mindspore import context, Tensor from mindspore import context, Tensor
from mindspore.train.serialization import save_checkpoint, load_checkpoint from mindspore.train.serialization import save_checkpoint
from src.adam import AdamWeightDecayOp as AdamWeightDecay from src.adam import AdamWeightDecayOp as AdamWeightDecay
from src.config import train_cfg, server_net_cfg from src.config import train_cfg, server_net_cfg
from src.utils import restore_params
from src.model import AlbertModelCLS from src.model import AlbertModelCLS
from src.cell_wrapper import NetworkWithCLSLoss, NetworkTrainCell from src.cell_wrapper import NetworkWithCLSLoss, NetworkTrainCell
@ -67,6 +66,7 @@ def parse_args():
parser.add_argument("--share_secrets_ratio", type=float, default=1.0) parser.add_argument("--share_secrets_ratio", type=float, default=1.0)
parser.add_argument("--cipher_time_window", type=int, default=300000) parser.add_argument("--cipher_time_window", type=int, default=300000)
parser.add_argument("--reconstruct_secrets_threshold", type=int, default=3) parser.add_argument("--reconstruct_secrets_threshold", type=int, default=3)
parser.add_argument("--config_file_path", type=str, default="")
parser.add_argument("--client_password", type=str, default="") parser.add_argument("--client_password", type=str, default="")
parser.add_argument("--server_password", type=str, default="") parser.add_argument("--server_password", type=str, default="")
parser.add_argument("--enable_ssl", type=ast.literal_eval, default=False) parser.add_argument("--enable_ssl", type=ast.literal_eval, default=False)
@ -77,7 +77,6 @@ def server_train(args):
start = time() start = time()
os.environ['CUDA_VISIBLE_DEVICES'] = args.device_id os.environ['CUDA_VISIBLE_DEVICES'] = args.device_id
model_path = args.model_path
output_dir = args.output_dir output_dir = args.output_dir
device_target = args.device_target device_target = args.device_target
@ -104,6 +103,7 @@ def server_train(args):
share_secrets_ratio = args.share_secrets_ratio share_secrets_ratio = args.share_secrets_ratio
cipher_time_window = args.cipher_time_window cipher_time_window = args.cipher_time_window
reconstruct_secrets_threshold = args.reconstruct_secrets_threshold reconstruct_secrets_threshold = args.reconstruct_secrets_threshold
config_file_path = args.config_file_path
client_password = args.client_password client_password = args.client_password
server_password = args.server_password server_password = args.server_password
enable_ssl = args.enable_ssl enable_ssl = args.enable_ssl
@ -136,6 +136,7 @@ def server_train(args):
"share_secrets_ratio": share_secrets_ratio, "share_secrets_ratio": share_secrets_ratio,
"cipher_time_window": cipher_time_window, "cipher_time_window": cipher_time_window,
"reconstruct_secrets_threshold": reconstruct_secrets_threshold, "reconstruct_secrets_threshold": reconstruct_secrets_threshold,
"config_file_path": config_file_path,
"client_password": client_password, "client_password": client_password,
"server_password": server_password, "server_password": server_password,
"enable_ssl": enable_ssl "enable_ssl": enable_ssl
@ -160,11 +161,6 @@ def server_train(args):
sys.stdout.flush() sys.stdout.flush()
start = time() start = time()
# train prepare
param_dict = load_checkpoint(model_path)
if 'learning_rate' in param_dict:
del param_dict['learning_rate']
# server optimizer # server optimizer
server_params = [_ for _ in network_with_cls_loss.trainable_params()] server_params = [_ for _ in network_with_cls_loss.trainable_params()]
server_decay_params = list( server_decay_params = list(
@ -183,8 +179,6 @@ def server_train(args):
eps=train_cfg.optimizer_cfg.AdamWeightDecay.eps) eps=train_cfg.optimizer_cfg.AdamWeightDecay.eps)
server_network_train_cell = NetworkTrainCell(network_with_cls_loss, optimizer=server_optimizer) server_network_train_cell = NetworkTrainCell(network_with_cls_loss, optimizer=server_optimizer)
restore_params(server_network_train_cell, param_dict)
print('Optimizer construction is done! Time cost: {}'.format(time() - start)) print('Optimizer construction is done! Time cost: {}'.format(time() - start))
sys.stdout.flush() sys.stdout.flush()
start = time() start = time()

View File

@ -29,6 +29,7 @@ parser.add_argument("--scheduler_manage_port", type=int, default=11202)
parser.add_argument("--client_password", type=str, default="") parser.add_argument("--client_password", type=str, default="")
parser.add_argument("--server_password", type=str, default="") parser.add_argument("--server_password", type=str, default="")
parser.add_argument("--enable_ssl", type=ast.literal_eval, default=False) parser.add_argument("--enable_ssl", type=ast.literal_eval, default=False)
parser.add_argument("--config_file_path", type=str, default="")
args, _ = parser.parse_known_args() args, _ = parser.parse_known_args()
device_target = args.device_target device_target = args.device_target
@ -41,6 +42,7 @@ scheduler_manage_port = args.scheduler_manage_port
client_password = args.client_password client_password = args.client_password
server_password = args.server_password server_password = args.server_password
enable_ssl = args.enable_ssl enable_ssl = args.enable_ssl
config_file_path = args.config_file_path
os.environ['MS_NODE_ID'] = "20" os.environ['MS_NODE_ID'] = "20"
cmd_sched = "execute_path=$(pwd) && self_path=$(dirname \"${script_self}\") && rm -rf ${execute_path}/scheduler/ &&" cmd_sched = "execute_path=$(pwd) && self_path=$(dirname \"${script_self}\") && rm -rf ${execute_path}/scheduler/ &&"
@ -58,6 +60,7 @@ cmd_sched += " --client_password=" + str(client_password)
cmd_sched += " --server_password=" + str(server_password) cmd_sched += " --server_password=" + str(server_password)
cmd_sched += " --enable_ssl=" + str(enable_ssl) cmd_sched += " --enable_ssl=" + str(enable_ssl)
cmd_sched += " --scheduler_manage_port=" + str(scheduler_manage_port) cmd_sched += " --scheduler_manage_port=" + str(scheduler_manage_port)
cmd_sched += " --config_file_path=" + config_file_path
cmd_sched += " > scheduler.log 2>&1 &" cmd_sched += " > scheduler.log 2>&1 &"
subprocess.call(['bash', '-c', cmd_sched]) subprocess.call(['bash', '-c', cmd_sched])

View File

@ -49,6 +49,7 @@ parser.add_argument("--reconstruct_secrets_threshold", type=int, default=3)
parser.add_argument("--client_password", type=str, default="") parser.add_argument("--client_password", type=str, default="")
parser.add_argument("--server_password", type=str, default="") parser.add_argument("--server_password", type=str, default="")
parser.add_argument("--enable_ssl", type=ast.literal_eval, default=False) parser.add_argument("--enable_ssl", type=ast.literal_eval, default=False)
parser.add_argument("--config_file_path", type=str, default="")
args, _ = parser.parse_known_args() args, _ = parser.parse_known_args()
device_target = args.device_target device_target = args.device_target
@ -78,6 +79,7 @@ reconstruct_secrets_threshold = args.reconstruct_secrets_threshold
client_password = args.client_password client_password = args.client_password
server_password = args.server_password server_password = args.server_password
enable_ssl = args.enable_ssl enable_ssl = args.enable_ssl
config_file_path = args.config_file_path
if local_server_num == -1: if local_server_num == -1:
local_server_num = server_num local_server_num = server_num
@ -118,6 +120,7 @@ for i in range(local_server_num):
cmd_server += " --server_password=" + str(server_password) cmd_server += " --server_password=" + str(server_password)
cmd_server += " --enable_ssl=" + str(enable_ssl) cmd_server += " --enable_ssl=" + str(enable_ssl)
cmd_server += " --reconstruct_secrets_threshold=" + str(reconstruct_secrets_threshold) cmd_server += " --reconstruct_secrets_threshold=" + str(reconstruct_secrets_threshold)
cmd_server += " --config_file_path=" + config_file_path
cmd_server += " > server.log 2>&1 &" cmd_server += " > server.log 2>&1 &"
import time import time