forked from mindspore-Ecosystem/mindspore
Add pretrained checkpoint loading and update the launch script
This commit is contained in:
parent
cee048f04c
commit
1e52582e23
|
@ -130,9 +130,11 @@ def main():
|
|||
log_files = []
|
||||
env = os.environ.copy()
|
||||
env['RANK_SIZE'] = str(args.nproc_per_node)
|
||||
cur_path = os.getcwd()
|
||||
for rank_id in range(0, args.nproc_per_node):
|
||||
os.chdir(cur_path)
|
||||
device_id = visible_devices[rank_id]
|
||||
device_dir = os.path.join(os.getcwd(), 'device{}'.format(rank_id))
|
||||
device_dir = os.path.join(cur_path, 'device{}'.format(rank_id))
|
||||
env['RANK_ID'] = str(rank_id)
|
||||
env['DEVICE_ID'] = str(device_id)
|
||||
if args.nproc_per_node > 1:
|
||||
|
@ -141,11 +143,12 @@ def main():
|
|||
if os.path.exists(device_dir):
|
||||
shutil.rmtree(device_dir)
|
||||
os.mkdir(device_dir)
|
||||
os.chdir(device_dir)
|
||||
cmd = [sys.executable, '-u']
|
||||
cmd.append(args.training_script)
|
||||
cmd.extend(args.training_script_args)
|
||||
log_file = open('{dir}/log{id}.log'.format(dir=device_dir, id=rank_id), 'w')
|
||||
process = subprocess.Popen(cmd, stdout=log_file, env=env)
|
||||
process = subprocess.Popen(cmd, stdout=log_file, stderr=log_file, env=env)
|
||||
processes.append(process)
|
||||
cmds.append(cmd)
|
||||
log_files.append(log_file)
|
||||
|
|
|
@ -37,6 +37,7 @@ from mindspore.train.model import Model, ParallelMode
|
|||
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, Callback
|
||||
from mindspore.train.loss_scale_manager import FixedLossScaleManager
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
import mindspore.dataset.engine as de
|
||||
from mindspore.communication.management import init
|
||||
|
||||
|
@ -46,6 +47,7 @@ de.config.set_seed(1)
|
|||
|
||||
parser = argparse.ArgumentParser(description='Image classification')
|
||||
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
|
||||
parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained checkpoint path')
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
|
@ -166,6 +168,9 @@ if __name__ == '__main__':
|
|||
dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True,
|
||||
repeat_num=epoch_size, batch_size=config.batch_size)
|
||||
step_size = dataset.get_dataset_size()
|
||||
if args_opt.pre_trained:
|
||||
param_dict = load_checkpoint(args_opt.pre_trained)
|
||||
load_param_into_net(net, param_dict)
|
||||
|
||||
loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
|
||||
lr = Tensor(get_lr(global_step=0, lr_init=0, lr_end=0, lr_max=config.lr,
|
||||
|
|
Loading…
Reference in New Issue