forked from OSSInnovation/mindspore
!6137 modify the model_zoo ckpt path
Merge pull request !6137 from TuDouNi/master
commit 8b0793eb84
@@ -132,7 +132,7 @@ config = ed({
     "save_checkpoint": True,
     "save_checkpoint_epochs": 1,
     "keep_checkpoint_max": 10,
-    "save_checkpoint_path": "./checkpoint",
+    "save_checkpoint_path": "./",

     "mindrecord_dir": "../MindRecord_COCO_TRAIN",
     "coco_root": "./cocodataset/",

@@ -136,7 +136,8 @@ if __name__ == '__main__':
     if config.save_checkpoint:
         ckptconfig = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * dataset_size,
                                       keep_checkpoint_max=config.keep_checkpoint_max)
-        ckpoint_cb = ModelCheckpoint(prefix='faster_rcnn', directory=config.save_checkpoint_path, config=ckptconfig)
+        save_checkpoint_path = os.path.join(config.save_checkpoint_path, "ckpt_" + str(rank) + "/")
+        ckpoint_cb = ModelCheckpoint(prefix='faster_rcnn', directory=save_checkpoint_path, config=ckptconfig)
         cb += [ckpoint_cb]

     model = Model(net)

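The two faster_rcnn hunks above establish the pattern this whole commit applies: the config default becomes "./", and the training script, rather than pointing ModelCheckpoint at the shared save_checkpoint_path, appends a rank-numbered subdirectory so concurrent workers cannot overwrite each other's checkpoint files. A minimal sketch of the idea, with hypothetical stand-ins for rank and dataset_size:

import os
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint

rank = 0            # a distributed run would obtain this from get_rank()
dataset_size = 100  # hypothetical steps per epoch

ckptconfig = CheckpointConfig(save_checkpoint_steps=dataset_size,
                              keep_checkpoint_max=10)
# Each worker writes into its own directory: ./ckpt_0/, ./ckpt_1/, ...
save_checkpoint_path = os.path.join("./", "ckpt_" + str(rank) + "/")
ckpoint_cb = ModelCheckpoint(prefix='faster_rcnn',
                             directory=save_checkpoint_path,
                             config=ckptconfig)
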
@@ -106,6 +106,7 @@ if __name__ == '__main__':
     context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target)
     device_num = int(os.environ.get("DEVICE_NUM", 1))

+    rank = 0
     if device_target == "Ascend":
         if args_opt.device_id is not None:
             context.set_context(device_id=args_opt.device_id)

@@ -117,6 +118,7 @@ if __name__ == '__main__':
             context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
                                               gradients_mean=True)
             init()
+            rank = get_rank()
     elif device_target == "GPU":
         init()

@@ -124,6 +126,7 @@ if __name__ == '__main__':
         context.reset_auto_parallel_context()
         context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
                                           gradients_mean=True)
+        rank = get_rank()
     else:
         raise ValueError("Unsupported platform.")

@@ -200,14 +203,13 @@ if __name__ == '__main__':
     if device_target == "Ascend":
         model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'},
                       amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=loss_scale_manager)
-        ckpt_save_dir = "./"
     else:  # GPU
         model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'},
                       amp_level="O2", keep_batchnorm_fp32=True, loss_scale_manager=loss_scale_manager)
-        ckpt_save_dir = "./ckpt_" + str(get_rank()) + "/"

     config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 5, keep_checkpoint_max=cfg.keep_checkpoint_max)
     time_cb = TimeMonitor(data_size=batch_num)
+    ckpt_save_dir = "./ckpt_" + str(rank) + "/"
     ckpoint_cb = ModelCheckpoint(prefix="train_googlenet_" + args_opt.dataset_name, directory=ckpt_save_dir,
                                  config=config_ck)
     loss_cb = LossMonitor()

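The four googlenet hunks are one refactor: rank starts at 0, is overwritten with get_rank() only after the communication backend is initialized in the Ascend or GPU branch, and the checkpoint directory is then derived from that single variable instead of being set differently per device target. A hedged sketch of the resulting control flow (the parallel-context setup is elided):

import os
from mindspore.communication.management import init, get_rank

device_num = int(os.environ.get("DEVICE_NUM", 1))
rank = 0                # default for single-device runs
if device_num > 1:
    init()              # must run before get_rank() is valid
    rank = get_rank()

# One expression now serves both Ascend and GPU:
ckpt_save_dir = "./ckpt_" + str(rank) + "/"
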
@@ -38,7 +38,7 @@ config_gpu = edict({
     'momentum': 0.9,
     'opt_eps': 1.0,
     'keep_checkpoint_max': 10,
-    'ckpt_path': './checkpoint/',
+    'ckpt_path': './',
     'is_save_on_master': 0,
     'dropout_keep_prob': 0.5,
     'has_bias': True,

@@ -65,7 +65,7 @@ config_ascend = edict({
     'momentum': 0.9,
     'opt_eps': 1.0,
     'keep_checkpoint_max': 10,
-    'ckpt_path': './checkpoint/',
+    'ckpt_path': './',
     'is_save_on_master': 0,
     'dropout_keep_prob': 0.8,
     'has_bias': False,

@@ -115,7 +115,8 @@ if __name__ == '__main__':
     time_cb = TimeMonitor(data_size=batches_per_epoch)
     callbacks = [loss_cb, time_cb]
     config_ck = CheckpointConfig(save_checkpoint_steps=batches_per_epoch, keep_checkpoint_max=cfg.keep_checkpoint_max)
-    ckpoint_cb = ModelCheckpoint(prefix=f"inceptionv3-rank{cfg.rank}", directory=cfg.ckpt_path, config=config_ck)
+    save_ckpt_path = os.path.join(cfg.ckpt_path, 'ckpt_' + str(cfg.rank) + '/')
+    ckpoint_cb = ModelCheckpoint(prefix=f"inceptionv3-rank{cfg.rank}", directory=save_ckpt_path, config=config_ck)
     if args_opt.is_distributed & cfg.is_save_on_master:
         if cfg.rank == 0:
             callbacks.append(ckpoint_cb)

@@ -139,7 +139,7 @@ config = ed({
     "save_checkpoint": True,
     "save_checkpoint_epochs": 1,
     "keep_checkpoint_max": 12,
-    "save_checkpoint_path": "./checkpoint",
+    "save_checkpoint_path": "./",

     "mindrecord_dir": "/home/mask_rcnn/MindRecord_COCO2017_Train",
     "coco_root": "/home/mask_rcnn/coco2017/",

@@ -131,7 +131,8 @@ if __name__ == '__main__':
     if config.save_checkpoint:
         ckptconfig = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * dataset_size,
                                       keep_checkpoint_max=config.keep_checkpoint_max)
-        ckpoint_cb = ModelCheckpoint(prefix='mask_rcnn', directory=config.save_checkpoint_path, config=ckptconfig)
+        save_checkpoint_path = os.path.join(config.save_checkpoint_path, 'ckpt_' + str(rank) + '/')
+        ckpoint_cb = ModelCheckpoint(prefix='mask_rcnn', directory=save_checkpoint_path, config=ckptconfig)
         cb += [ckpoint_cb]

     model = Model(net)

@@ -37,7 +37,8 @@ def set_config(args):
         "save_checkpoint_epochs": 1,
         "keep_checkpoint_max": 20,
         "save_checkpoint_path": "./checkpoint",
-        "platform": args.platform
+        "platform": args.platform,
+        "run_distribute": False
     })
     config_gpu = ed({
         "num_classes": 1000,

@@ -76,14 +76,12 @@ def config_ckpoint(config, lr, step_size):
     if config.save_checkpoint:
         config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size,
                                      keep_checkpoint_max=config.keep_checkpoint_max)
-        ckpt_save_dir = config.save_checkpoint_path
-
-        if config.platform == "GPU":
-            if config.run_distribute:
-                ckpt_save_dir += "ckpt_" + str(get_rank()) + "/"
-            else:
-                ckpt_save_dir += "ckpt_" + "/"
+        rank = 0
+        if config.run_distribute:
+            rank = get_rank()
+
+        ckpt_save_dir = config.save_checkpoint_path + "ckpt_" + str(rank) + "/"
         ckpt_cb = ModelCheckpoint(prefix="mobilenetV2", directory=ckpt_save_dir, config=config_ck)
         cb += [ckpt_cb]
     return cb

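Together with the run_distribute flag added to the config in the previous hunk, this removes the GPU-only branching in config_ckpoint: the directory is computed the same way on every platform. A sketch of the simplified helper, assuming a config object carrying the fields shown in this diff:

from mindspore.communication.management import get_rank
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint

def config_ckpoint(config, lr, step_size):
    cb = []  # the real function also appends monitor callbacks
    if config.save_checkpoint:
        config_ck = CheckpointConfig(
            save_checkpoint_steps=config.save_checkpoint_epochs * step_size,
            keep_checkpoint_max=config.keep_checkpoint_max)
        rank = get_rank() if config.run_distribute else 0
        ckpt_save_dir = config.save_checkpoint_path + "ckpt_" + str(rank) + "/"
        cb += [ModelCheckpoint(prefix="mobilenetV2",
                               directory=ckpt_save_dir, config=config_ck)]
    return cb
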
@@ -25,7 +25,7 @@ nasnet_a_mobile_config_gpu = edict({
     'work_nums': 8,
     'epoch_size': 312,
     'keep_checkpoint_max': 100,
-    'ckpt_path': './nasnet_a_mobile_checkpoint/',
+    'ckpt_path': './',
     'is_save_on_master': 0,

     ### Dataset Config

@@ -102,7 +102,8 @@ if __name__ == '__main__':
     time_cb = TimeMonitor(data_size=batches_per_epoch)
     callbacks = [loss_cb, time_cb]
     config_ck = CheckpointConfig(save_checkpoint_steps=batches_per_epoch, keep_checkpoint_max=cfg.keep_checkpoint_max)
-    ckpoint_cb = ModelCheckpoint(prefix=f"nasnet-a-mobile-rank{cfg.rank}", directory=cfg.ckpt_path, config=config_ck)
+    save_ckpt_path = os.path.join(cfg.ckpt_path, 'ckpt_' + str(cfg.rank) + '/')
+    ckpoint_cb = ModelCheckpoint(prefix=f"nasnet-a-mobile-rank{cfg.rank}", directory=save_ckpt_path, config=config_ck)
     if args_opt.is_distributed & cfg.is_save_on_master:
         if cfg.rank == 0:
             callbacks.append(ckpoint_cb)

@@ -90,7 +90,7 @@ if __name__ == '__main__':
                                           gradients_mean=True)
         if args_opt.net == "resnet50":
             context.set_auto_parallel_context(all_reduce_fusion_config=[85, 160])
-        ckpt_save_dir = config.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/"
+        ckpt_save_dir = config.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/"

     # create dataset
     dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, repeat_num=1,

@@ -100,7 +100,7 @@ if __name__ == '__main__':
         init()
         context.set_auto_parallel_context(device_num=get_group_size(), parallel_mode=ParallelMode.DATA_PARALLEL,
                                           gradients_mean=True, all_reduce_fusion_config=[107])
-        ckpt_save_dir = config.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/"
+        ckpt_save_dir = ckpt_save_dir + "ckpt_" + str(get_rank()) + "/"

     # create dataset
     dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, repeat_num=1,

@@ -280,8 +280,9 @@ def train(cloud_args=None):
     if args.rank_save_ckpt_flag:
         ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval * args.steps_per_epoch,
                                        keep_checkpoint_max=args.ckpt_save_max)
+        save_ckpt_path = os.path.join(args.outputs_dir, 'ckpt_' + str(args.rank) + '/')
         ckpt_cb = ModelCheckpoint(config=ckpt_config,
-                                  directory=args.outputs_dir,
+                                  directory=save_ckpt_path,
                                   prefix='{}'.format(args.rank))
         callbacks.append(ckpt_cb)

@@ -25,7 +25,7 @@ config_gpu = edict({
     'work_nums': 8,
     'epoch_size': 250,
     'keep_checkpoint_max': 100,
-    'ckpt_path': './checkpoint/',
+    'ckpt_path': './',
     'is_save_on_master': 0,

     ### Dataset Config

@@ -110,7 +110,8 @@ if __name__ == '__main__':
     time_cb = TimeMonitor(data_size=batches_per_epoch)
     callbacks = [loss_cb, time_cb]
     config_ck = CheckpointConfig(save_checkpoint_steps=batches_per_epoch, keep_checkpoint_max=cfg.keep_checkpoint_max)
-    ckpoint_cb = ModelCheckpoint(prefix=f"shufflenet-rank{cfg.rank}", directory=cfg.ckpt_path, config=config_ck)
+    save_ckpt_path = os.path.join(cfg.ckpt_path, 'ckpt_' + str(cfg.rank) + '/')
+    ckpoint_cb = ModelCheckpoint(prefix=f"shufflenet-rank{cfg.rank}", directory=save_ckpt_path, config=config_ck)
     if args_opt.is_distributed & cfg.is_save_on_master:
         if cfg.rank == 0:
             callbacks.append(ckpoint_cb)

@@ -118,7 +118,8 @@ def main():

     # checkpoint
     ckpt_config = CheckpointConfig(save_checkpoint_steps=dataset_size * args_opt.save_checkpoint_epochs)
-    ckpoint_cb = ModelCheckpoint(prefix="ssd", directory=None, config=ckpt_config)
+    save_ckpt_path = './ckpt_' + str(rank) + '/'
+    ckpoint_cb = ModelCheckpoint(prefix="ssd", directory=save_ckpt_path, config=ckpt_config)

     if args_opt.pre_trained:
         if args_opt.pre_trained_epoch_size <= 0:

@@ -226,8 +226,9 @@ if __name__ == '__main__':
     if args.rank_save_ckpt_flag:
         ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval * args.steps_per_epoch,
                                        keep_checkpoint_max=args.ckpt_save_max)
+        save_ckpt_path = os.path.join(args.outputs_dir, 'ckpt_' + str(args.rank) + '/')
         ckpt_cb = ModelCheckpoint(config=ckpt_config,
-                                  directory=args.outputs_dir,
+                                  directory=save_ckpt_path,
                                   prefix='{}'.format(args.rank))
         callbacks.append(ckpt_cb)

@@ -27,5 +27,5 @@ config = EasyDict({
     "save_checkpoint": True,
     "save_checkpoint_steps": 97,
     "keep_checkpoint_max": 30,
-    "save_checkpoint_path": "./checkpoint",
+    "save_checkpoint_path": "./",
 })

@@ -98,6 +98,7 @@ if __name__ == '__main__':
     if cf.save_checkpoint:
         config_ck = CheckpointConfig(save_checkpoint_steps=cf.save_checkpoint_steps,
                                      keep_checkpoint_max=cf.keep_checkpoint_max)
-        ckpt_cb = ModelCheckpoint(prefix="warpctc", directory=cf.save_checkpoint_path + str(rank), config=config_ck)
+        save_ckpt_path = os.path.join(cf.save_checkpoint_path, 'ckpt_' + str(rank) + '/')
+        ckpt_cb = ModelCheckpoint(prefix="warpctc", directory=save_ckpt_path, config=config_ck)
         callbacks.append(ckpt_cb)
     model.train(cf.epoch_size, dataset, callbacks=callbacks)

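Besides isolating ranks, the warpctc hunk fixes a path-construction quirk: the old cf.save_checkpoint_path + str(rank) concatenation produced directories like ./0 with no prefix or separator, whereas os.path.join yields the intended ./ckpt_0/ layout. For example:

import os

save_checkpoint_path = "./"
rank = 0
print(save_checkpoint_path + str(rank))                               # ./0  (old behaviour)
print(os.path.join(save_checkpoint_path, 'ckpt_' + str(rank) + '/'))  # ./ckpt_0/
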
@@ -287,8 +287,9 @@ def train():
         ckpt_max_num = args.max_epoch * args.steps_per_epoch // args.ckpt_interval
         ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval,
                                        keep_checkpoint_max=ckpt_max_num)
+        save_ckpt_path = os.path.join(args.outputs_dir, 'ckpt_' + str(args.rank) + '/')
         ckpt_cb = ModelCheckpoint(config=ckpt_config,
-                                  directory=args.outputs_dir,
+                                  directory=save_ckpt_path,
                                   prefix='{}'.format(args.rank))
         cb_params = _InternalCallbackParam()
         cb_params.train_network = network

@@ -291,8 +291,9 @@ def train():
         ckpt_max_num = args.max_epoch * args.steps_per_epoch // args.ckpt_interval
         ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval,
                                        keep_checkpoint_max=ckpt_max_num)
+        save_ckpt_path = os.path.join(args.outputs_dir, 'ckpt_' + str(args.rank) + '/')
         ckpt_cb = ModelCheckpoint(config=ckpt_config,
-                                  directory=args.outputs_dir,
+                                  directory=save_ckpt_path,
                                   prefix='{}'.format(args.rank))
         cb_params = _InternalCallbackParam()
         cb_params.train_network = network

@@ -137,7 +137,7 @@ def main():

     # checkpoint
     ckpt_config = CheckpointConfig(save_checkpoint_steps=dataset_size * args_opt.save_checkpoint_epochs)
-    ckpoint_cb = ModelCheckpoint(prefix="yolov3", directory=None, config=ckpt_config)
+    ckpoint_cb = ModelCheckpoint(prefix="yolov3", directory='./ckpt_' + str(rank) + '/', config=ckpt_config)

     if args_opt.pre_trained:
         if args_opt.pre_trained_epoch_size <= 0:

@@ -20,6 +20,7 @@ python run_pretrain.py
 import os
 import argparse
 import mindspore.communication.management as D
+from mindspore.communication.management import get_rank
 import mindspore.common.dtype as mstype
 from mindspore import context
 from mindspore.train.model import Model

@@ -82,7 +83,7 @@ def run_pretrain():
         D.init()
         device_num = D.get_group_size()
         rank = D.get_rank()
-        ckpt_save_dir = args_opt.save_checkpoint_path + 'ckpt_' + str(rank) + '/'
+        ckpt_save_dir = args_opt.save_checkpoint_path + 'ckpt_' + str(get_rank()) + '/'

         context.reset_auto_parallel_context()
         context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,

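In the BERT pretraining script the local rank is still bound via D.get_rank(), and the commit merely standardizes on the newly imported get_rank() when composing ckpt_save_dir; both calls return the same value once D.init() has run, so this is a consistency cleanup rather than a behaviour change. A condensed sketch with a placeholder base path:

import mindspore.communication.management as D
from mindspore.communication.management import get_rank

D.init()
device_num = D.get_group_size()
rank = D.get_rank()
# `rank` and get_rank() are interchangeable here; the commit prefers
# the direct call when building the per-rank save directory.
ckpt_save_dir = "./" + 'ckpt_' + str(get_rank()) + '/'
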
@@ -28,6 +28,7 @@ from src.model_thor import Model
 from src.utils import LossCallBack, BertLearningRate
 import mindspore.common.dtype as mstype
 import mindspore.communication.management as D
+from mindspore.communication.management import get_rank
 from mindspore import context
 from mindspore import log as logger
 from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecay

@@ -84,7 +85,7 @@ def run_pretrain():
         D.init()
         device_num = D.get_group_size()
         rank = D.get_rank()
-        ckpt_save_dir = args_opt.save_checkpoint_path + 'ckpt_' + str(rank) + '/'
+        ckpt_save_dir = args_opt.save_checkpoint_path + 'ckpt_' + str(get_rank()) + '/'

         context.reset_auto_parallel_context()
         context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,

@@ -14,6 +14,7 @@
 # ============================================================================
 """Transformer training script."""

+import os
 import time
 import argparse
 import ast

@@ -27,6 +28,7 @@ from mindspore.train.callback import CheckpointConfig, ModelCheckpoint
 from mindspore.train.callback import Callback, TimeMonitor
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
 import mindspore.communication.management as D
+from mindspore.communication.management import get_rank
 from mindspore.context import ParallelMode
 from mindspore import context
 from mindspore.common import set_seed

@@ -125,6 +127,7 @@ def run_transformer_train():
                                           parameter_broadcast=True, device_num=device_num)
         D.init()
         rank_id = args.device_id % device_num
+        save_ckpt_path = os.path.join(args.save_checkpoint_path, 'ckpt_' + str(get_rank()) + '/')
     else:
         device_num = 1
         rank_id = 0

@@ -153,7 +156,7 @@ def run_transformer_train():
     if device_num == 1 or (device_num > 1 and rank_id == 0):
         ckpt_config = CheckpointConfig(save_checkpoint_steps=args.save_checkpoint_steps,
                                        keep_checkpoint_max=args.save_checkpoint_num)
-        ckpoint_cb = ModelCheckpoint(prefix='transformer', directory=args.save_checkpoint_path, config=ckpt_config)
+        ckpoint_cb = ModelCheckpoint(prefix='transformer', directory=save_ckpt_path, config=ckpt_config)
         callbacks.append(ckpoint_cb)

     if args.enable_lossscale == "true":

@@ -104,6 +104,7 @@ if __name__ == '__main__':
     if train_config.save_checkpoint:
         if rank_size:
             train_config.ckpt_file_name_prefix = train_config.ckpt_file_name_prefix + str(get_rank())
+            args_opt.ckpt_path = os.path.join(args_opt.ckpt_path, 'ckpt_' + str(get_rank()) + '/')
         if args_opt.device_target == "GPU":
             config_ck = CheckpointConfig(save_checkpoint_steps=steps_size,
                                          keep_checkpoint_max=train_config.keep_checkpoint_max)

@@ -127,7 +127,7 @@ def train_and_eval(config):
     ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size()*epochs,
                                   keep_checkpoint_max=5, integrated_save=False)
     ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
-                                 directory=config.ckpt_path, config=ckptconfig)
+                                 directory=config.ckpt_path + '/ckpt_' + str(get_rank()) + '/', config=ckptconfig)
     context.set_auto_parallel_context(strategy_ckpt_save_file=config.stra_ckpt)
     callback_list = [TimeMonitor(
         ds_train.get_dataset_size()), eval_callback, callback]

@@ -102,7 +102,8 @@ def train_and_eval(config):
     callback = LossCallBack(config=config)
     ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(), keep_checkpoint_max=5)
     ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
-                                 directory=config.ckpt_path, config=ckptconfig)
+                                 directory=config.ckpt_path + '/ckpt_' + str(get_rank()) + '/',
+                                 config=ckptconfig)
     out = model.eval(ds_eval)
     print("=====" * 5 + "model.eval() initialized: {}".format(out))
     callback_list = [TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback]

@@ -104,7 +104,8 @@ def train_and_eval(config):
     callback = LossCallBack(config=config)
     ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(), keep_checkpoint_max=5)
     ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
-                                 directory=config.ckpt_path, config=ckptconfig)
+                                 directory=config.ckpt_path + '/ckpt_' + str(get_rank()) + '/',
+                                 config=ckptconfig)
     callback_list = [TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback]
     if get_rank() == 0:
         callback_list.append(ckpoint_cb)

@@ -98,7 +98,8 @@ def train_and_eval(config):
     ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size()*config.epochs,
                                   keep_checkpoint_max=10)
     ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
-                                 directory=config.ckpt_path, config=ckptconfig)
+                                 directory=config.ckpt_path + '/ckpt_' + str(get_rank()) + '/',
+                                 config=ckptconfig)
     callback_list = [TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback]
     if int(get_rank()) == 0:
         callback_list.append(ckpoint_cb)

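In the wide_and_deep variants the rank directory is appended inline to config.ckpt_path, and the existing rank-0 gate still decides whether the callback is registered, so every worker computes a distinct target path but only one actually saves. A hedged sketch of that gating (ckpt_path and dataset_size stand in for the real config and dataset objects):

from mindspore.communication.management import get_rank
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint

ckpt_path = "./"     # stand-in for config.ckpt_path
dataset_size = 100   # stand-in for ds_train.get_dataset_size()

ckptconfig = CheckpointConfig(save_checkpoint_steps=dataset_size,
                              keep_checkpoint_max=5)
ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
                             directory=ckpt_path + '/ckpt_' + str(get_rank()) + '/',
                             config=ckptconfig)

callback_list = []
if get_rank() == 0:  # only rank 0 registers the checkpoint callback
    callback_list.append(ckpoint_cb)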