forked from mindspore-Ecosystem/mindspore
!23986 add config file path
Merge pull request !23986 from wtcheng/master
This commit is contained in:
commit
3e1ef7823a
|
@ -20,10 +20,9 @@ import ast
|
|||
from time import time
|
||||
import numpy as np
|
||||
from mindspore import context, Tensor
|
||||
from mindspore.train.serialization import save_checkpoint, load_checkpoint
|
||||
from mindspore.train.serialization import save_checkpoint
|
||||
from src.adam import AdamWeightDecayOp as AdamWeightDecay
|
||||
from src.config import train_cfg, server_net_cfg
|
||||
from src.utils import restore_params
|
||||
from src.model import AlbertModelCLS
|
||||
from src.cell_wrapper import NetworkWithCLSLoss, NetworkTrainCell
|
||||
|
||||
|
@ -67,6 +66,7 @@ def parse_args():
|
|||
parser.add_argument("--share_secrets_ratio", type=float, default=1.0)
|
||||
parser.add_argument("--cipher_time_window", type=int, default=300000)
|
||||
parser.add_argument("--reconstruct_secrets_threshold", type=int, default=3)
|
||||
parser.add_argument("--config_file_path", type=str, default="")
|
||||
parser.add_argument("--client_password", type=str, default="")
|
||||
parser.add_argument("--server_password", type=str, default="")
|
||||
parser.add_argument("--enable_ssl", type=ast.literal_eval, default=False)
|
||||
|
@ -77,7 +77,6 @@ def server_train(args):
|
|||
start = time()
|
||||
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = args.device_id
|
||||
model_path = args.model_path
|
||||
output_dir = args.output_dir
|
||||
|
||||
device_target = args.device_target
|
||||
|
@ -104,6 +103,7 @@ def server_train(args):
|
|||
share_secrets_ratio = args.share_secrets_ratio
|
||||
cipher_time_window = args.cipher_time_window
|
||||
reconstruct_secrets_threshold = args.reconstruct_secrets_threshold
|
||||
config_file_path = args.config_file_path
|
||||
client_password = args.client_password
|
||||
server_password = args.server_password
|
||||
enable_ssl = args.enable_ssl
|
||||
|
@ -136,6 +136,7 @@ def server_train(args):
|
|||
"share_secrets_ratio": share_secrets_ratio,
|
||||
"cipher_time_window": cipher_time_window,
|
||||
"reconstruct_secrets_threshold": reconstruct_secrets_threshold,
|
||||
"config_file_path": config_file_path,
|
||||
"client_password": client_password,
|
||||
"server_password": server_password,
|
||||
"enable_ssl": enable_ssl
|
||||
|
@ -160,11 +161,6 @@ def server_train(args):
|
|||
sys.stdout.flush()
|
||||
start = time()
|
||||
|
||||
# train prepare
|
||||
param_dict = load_checkpoint(model_path)
|
||||
if 'learning_rate' in param_dict:
|
||||
del param_dict['learning_rate']
|
||||
|
||||
# server optimizer
|
||||
server_params = [_ for _ in network_with_cls_loss.trainable_params()]
|
||||
server_decay_params = list(
|
||||
|
@ -183,8 +179,6 @@ def server_train(args):
|
|||
eps=train_cfg.optimizer_cfg.AdamWeightDecay.eps)
|
||||
server_network_train_cell = NetworkTrainCell(network_with_cls_loss, optimizer=server_optimizer)
|
||||
|
||||
restore_params(server_network_train_cell, param_dict)
|
||||
|
||||
print('Optimizer construction is done! Time cost: {}'.format(time() - start))
|
||||
sys.stdout.flush()
|
||||
start = time()
|
||||
|
|
|
@ -29,6 +29,7 @@ parser.add_argument("--scheduler_manage_port", type=int, default=11202)
|
|||
parser.add_argument("--client_password", type=str, default="")
|
||||
parser.add_argument("--server_password", type=str, default="")
|
||||
parser.add_argument("--enable_ssl", type=ast.literal_eval, default=False)
|
||||
parser.add_argument("--config_file_path", type=str, default="")
|
||||
|
||||
args, _ = parser.parse_known_args()
|
||||
device_target = args.device_target
|
||||
|
@ -41,6 +42,7 @@ scheduler_manage_port = args.scheduler_manage_port
|
|||
client_password = args.client_password
|
||||
server_password = args.server_password
|
||||
enable_ssl = args.enable_ssl
|
||||
config_file_path = args.config_file_path
|
||||
|
||||
os.environ['MS_NODE_ID'] = "20"
|
||||
cmd_sched = "execute_path=$(pwd) && self_path=$(dirname \"${script_self}\") && rm -rf ${execute_path}/scheduler/ &&"
|
||||
|
@ -58,6 +60,7 @@ cmd_sched += " --client_password=" + str(client_password)
|
|||
cmd_sched += " --server_password=" + str(server_password)
|
||||
cmd_sched += " --enable_ssl=" + str(enable_ssl)
|
||||
cmd_sched += " --scheduler_manage_port=" + str(scheduler_manage_port)
|
||||
cmd_sched += " --config_file_path=" + config_file_path
|
||||
cmd_sched += " > scheduler.log 2>&1 &"
|
||||
|
||||
subprocess.call(['bash', '-c', cmd_sched])
|
||||
|
|
|
@ -49,6 +49,7 @@ parser.add_argument("--reconstruct_secrets_threshold", type=int, default=3)
|
|||
parser.add_argument("--client_password", type=str, default="")
|
||||
parser.add_argument("--server_password", type=str, default="")
|
||||
parser.add_argument("--enable_ssl", type=ast.literal_eval, default=False)
|
||||
parser.add_argument("--config_file_path", type=str, default="")
|
||||
|
||||
args, _ = parser.parse_known_args()
|
||||
device_target = args.device_target
|
||||
|
@ -78,6 +79,7 @@ reconstruct_secrets_threshold = args.reconstruct_secrets_threshold
|
|||
client_password = args.client_password
|
||||
server_password = args.server_password
|
||||
enable_ssl = args.enable_ssl
|
||||
config_file_path = args.config_file_path
|
||||
|
||||
if local_server_num == -1:
|
||||
local_server_num = server_num
|
||||
|
@ -118,6 +120,7 @@ for i in range(local_server_num):
|
|||
cmd_server += " --server_password=" + str(server_password)
|
||||
cmd_server += " --enable_ssl=" + str(enable_ssl)
|
||||
cmd_server += " --reconstruct_secrets_threshold=" + str(reconstruct_secrets_threshold)
|
||||
cmd_server += " --config_file_path=" + config_file_path
|
||||
cmd_server += " > server.log 2>&1 &"
|
||||
|
||||
import time
|
||||
|
|
Loading…
Reference in New Issue