forked from mindspore-Ecosystem/mindspore
!23986 add config file path
Merge pull request !23986 from wtcheng/master
This commit is contained in:
commit
3e1ef7823a
|
@ -20,10 +20,9 @@ import ast
|
||||||
from time import time
|
from time import time
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from mindspore import context, Tensor
|
from mindspore import context, Tensor
|
||||||
from mindspore.train.serialization import save_checkpoint, load_checkpoint
|
from mindspore.train.serialization import save_checkpoint
|
||||||
from src.adam import AdamWeightDecayOp as AdamWeightDecay
|
from src.adam import AdamWeightDecayOp as AdamWeightDecay
|
||||||
from src.config import train_cfg, server_net_cfg
|
from src.config import train_cfg, server_net_cfg
|
||||||
from src.utils import restore_params
|
|
||||||
from src.model import AlbertModelCLS
|
from src.model import AlbertModelCLS
|
||||||
from src.cell_wrapper import NetworkWithCLSLoss, NetworkTrainCell
|
from src.cell_wrapper import NetworkWithCLSLoss, NetworkTrainCell
|
||||||
|
|
||||||
|
@ -67,6 +66,7 @@ def parse_args():
|
||||||
parser.add_argument("--share_secrets_ratio", type=float, default=1.0)
|
parser.add_argument("--share_secrets_ratio", type=float, default=1.0)
|
||||||
parser.add_argument("--cipher_time_window", type=int, default=300000)
|
parser.add_argument("--cipher_time_window", type=int, default=300000)
|
||||||
parser.add_argument("--reconstruct_secrets_threshold", type=int, default=3)
|
parser.add_argument("--reconstruct_secrets_threshold", type=int, default=3)
|
||||||
|
parser.add_argument("--config_file_path", type=str, default="")
|
||||||
parser.add_argument("--client_password", type=str, default="")
|
parser.add_argument("--client_password", type=str, default="")
|
||||||
parser.add_argument("--server_password", type=str, default="")
|
parser.add_argument("--server_password", type=str, default="")
|
||||||
parser.add_argument("--enable_ssl", type=ast.literal_eval, default=False)
|
parser.add_argument("--enable_ssl", type=ast.literal_eval, default=False)
|
||||||
|
@ -77,7 +77,6 @@ def server_train(args):
|
||||||
start = time()
|
start = time()
|
||||||
|
|
||||||
os.environ['CUDA_VISIBLE_DEVICES'] = args.device_id
|
os.environ['CUDA_VISIBLE_DEVICES'] = args.device_id
|
||||||
model_path = args.model_path
|
|
||||||
output_dir = args.output_dir
|
output_dir = args.output_dir
|
||||||
|
|
||||||
device_target = args.device_target
|
device_target = args.device_target
|
||||||
|
@ -104,6 +103,7 @@ def server_train(args):
|
||||||
share_secrets_ratio = args.share_secrets_ratio
|
share_secrets_ratio = args.share_secrets_ratio
|
||||||
cipher_time_window = args.cipher_time_window
|
cipher_time_window = args.cipher_time_window
|
||||||
reconstruct_secrets_threshold = args.reconstruct_secrets_threshold
|
reconstruct_secrets_threshold = args.reconstruct_secrets_threshold
|
||||||
|
config_file_path = args.config_file_path
|
||||||
client_password = args.client_password
|
client_password = args.client_password
|
||||||
server_password = args.server_password
|
server_password = args.server_password
|
||||||
enable_ssl = args.enable_ssl
|
enable_ssl = args.enable_ssl
|
||||||
|
@ -136,6 +136,7 @@ def server_train(args):
|
||||||
"share_secrets_ratio": share_secrets_ratio,
|
"share_secrets_ratio": share_secrets_ratio,
|
||||||
"cipher_time_window": cipher_time_window,
|
"cipher_time_window": cipher_time_window,
|
||||||
"reconstruct_secrets_threshold": reconstruct_secrets_threshold,
|
"reconstruct_secrets_threshold": reconstruct_secrets_threshold,
|
||||||
|
"config_file_path": config_file_path,
|
||||||
"client_password": client_password,
|
"client_password": client_password,
|
||||||
"server_password": server_password,
|
"server_password": server_password,
|
||||||
"enable_ssl": enable_ssl
|
"enable_ssl": enable_ssl
|
||||||
|
@ -160,11 +161,6 @@ def server_train(args):
|
||||||
sys.stdout.flush()
|
sys.stdout.flush()
|
||||||
start = time()
|
start = time()
|
||||||
|
|
||||||
# train prepare
|
|
||||||
param_dict = load_checkpoint(model_path)
|
|
||||||
if 'learning_rate' in param_dict:
|
|
||||||
del param_dict['learning_rate']
|
|
||||||
|
|
||||||
# server optimizer
|
# server optimizer
|
||||||
server_params = [_ for _ in network_with_cls_loss.trainable_params()]
|
server_params = [_ for _ in network_with_cls_loss.trainable_params()]
|
||||||
server_decay_params = list(
|
server_decay_params = list(
|
||||||
|
@ -183,8 +179,6 @@ def server_train(args):
|
||||||
eps=train_cfg.optimizer_cfg.AdamWeightDecay.eps)
|
eps=train_cfg.optimizer_cfg.AdamWeightDecay.eps)
|
||||||
server_network_train_cell = NetworkTrainCell(network_with_cls_loss, optimizer=server_optimizer)
|
server_network_train_cell = NetworkTrainCell(network_with_cls_loss, optimizer=server_optimizer)
|
||||||
|
|
||||||
restore_params(server_network_train_cell, param_dict)
|
|
||||||
|
|
||||||
print('Optimizer construction is done! Time cost: {}'.format(time() - start))
|
print('Optimizer construction is done! Time cost: {}'.format(time() - start))
|
||||||
sys.stdout.flush()
|
sys.stdout.flush()
|
||||||
start = time()
|
start = time()
|
||||||
|
|
|
@ -29,6 +29,7 @@ parser.add_argument("--scheduler_manage_port", type=int, default=11202)
|
||||||
parser.add_argument("--client_password", type=str, default="")
|
parser.add_argument("--client_password", type=str, default="")
|
||||||
parser.add_argument("--server_password", type=str, default="")
|
parser.add_argument("--server_password", type=str, default="")
|
||||||
parser.add_argument("--enable_ssl", type=ast.literal_eval, default=False)
|
parser.add_argument("--enable_ssl", type=ast.literal_eval, default=False)
|
||||||
|
parser.add_argument("--config_file_path", type=str, default="")
|
||||||
|
|
||||||
args, _ = parser.parse_known_args()
|
args, _ = parser.parse_known_args()
|
||||||
device_target = args.device_target
|
device_target = args.device_target
|
||||||
|
@ -41,6 +42,7 @@ scheduler_manage_port = args.scheduler_manage_port
|
||||||
client_password = args.client_password
|
client_password = args.client_password
|
||||||
server_password = args.server_password
|
server_password = args.server_password
|
||||||
enable_ssl = args.enable_ssl
|
enable_ssl = args.enable_ssl
|
||||||
|
config_file_path = args.config_file_path
|
||||||
|
|
||||||
os.environ['MS_NODE_ID'] = "20"
|
os.environ['MS_NODE_ID'] = "20"
|
||||||
cmd_sched = "execute_path=$(pwd) && self_path=$(dirname \"${script_self}\") && rm -rf ${execute_path}/scheduler/ &&"
|
cmd_sched = "execute_path=$(pwd) && self_path=$(dirname \"${script_self}\") && rm -rf ${execute_path}/scheduler/ &&"
|
||||||
|
@ -58,6 +60,7 @@ cmd_sched += " --client_password=" + str(client_password)
|
||||||
cmd_sched += " --server_password=" + str(server_password)
|
cmd_sched += " --server_password=" + str(server_password)
|
||||||
cmd_sched += " --enable_ssl=" + str(enable_ssl)
|
cmd_sched += " --enable_ssl=" + str(enable_ssl)
|
||||||
cmd_sched += " --scheduler_manage_port=" + str(scheduler_manage_port)
|
cmd_sched += " --scheduler_manage_port=" + str(scheduler_manage_port)
|
||||||
|
cmd_sched += " --config_file_path=" + config_file_path
|
||||||
cmd_sched += " > scheduler.log 2>&1 &"
|
cmd_sched += " > scheduler.log 2>&1 &"
|
||||||
|
|
||||||
subprocess.call(['bash', '-c', cmd_sched])
|
subprocess.call(['bash', '-c', cmd_sched])
|
||||||
|
|
|
@ -49,6 +49,7 @@ parser.add_argument("--reconstruct_secrets_threshold", type=int, default=3)
|
||||||
parser.add_argument("--client_password", type=str, default="")
|
parser.add_argument("--client_password", type=str, default="")
|
||||||
parser.add_argument("--server_password", type=str, default="")
|
parser.add_argument("--server_password", type=str, default="")
|
||||||
parser.add_argument("--enable_ssl", type=ast.literal_eval, default=False)
|
parser.add_argument("--enable_ssl", type=ast.literal_eval, default=False)
|
||||||
|
parser.add_argument("--config_file_path", type=str, default="")
|
||||||
|
|
||||||
args, _ = parser.parse_known_args()
|
args, _ = parser.parse_known_args()
|
||||||
device_target = args.device_target
|
device_target = args.device_target
|
||||||
|
@ -78,6 +79,7 @@ reconstruct_secrets_threshold = args.reconstruct_secrets_threshold
|
||||||
client_password = args.client_password
|
client_password = args.client_password
|
||||||
server_password = args.server_password
|
server_password = args.server_password
|
||||||
enable_ssl = args.enable_ssl
|
enable_ssl = args.enable_ssl
|
||||||
|
config_file_path = args.config_file_path
|
||||||
|
|
||||||
if local_server_num == -1:
|
if local_server_num == -1:
|
||||||
local_server_num = server_num
|
local_server_num = server_num
|
||||||
|
@ -118,6 +120,7 @@ for i in range(local_server_num):
|
||||||
cmd_server += " --server_password=" + str(server_password)
|
cmd_server += " --server_password=" + str(server_password)
|
||||||
cmd_server += " --enable_ssl=" + str(enable_ssl)
|
cmd_server += " --enable_ssl=" + str(enable_ssl)
|
||||||
cmd_server += " --reconstruct_secrets_threshold=" + str(reconstruct_secrets_threshold)
|
cmd_server += " --reconstruct_secrets_threshold=" + str(reconstruct_secrets_threshold)
|
||||||
|
cmd_server += " --config_file_path=" + config_file_path
|
||||||
cmd_server += " > server.log 2>&1 &"
|
cmd_server += " > server.log 2>&1 &"
|
||||||
|
|
||||||
import time
|
import time
|
||||||
|
|
Loading…
Reference in New Issue