forked from mindspore-Ecosystem/mindspore
modify config param
parent 4fff5f442c
commit efe41902b8

@@ -51,8 +51,8 @@ Parameters for both training and evaluating can be set in config.py.
 "image_height": 224, # image height
 "image_width": 224, # image width
 "save_checkpoint": True, # whether save checkpoint or not
-"save_checkpoint_steps": 500, # the step interval between two checkpoints. By default, the last checkpoint will be saved after the last step
-"keep_checkpoint_max": 40, # only keep the last keep_checkpoint_max checkpoint
+"save_checkpoint_epochs": 1, # the epoch interval between two checkpoints. By default, the last checkpoint will be saved after the last epoch
+"keep_checkpoint_max": 10, # only keep the last keep_checkpoint_max checkpoint
 "save_checkpoint_path": "./", # path to save checkpoint relative to the executed path
 "warmup_epochs": 0, # number of warmup epoch
 "lr_decay_mode": "cosine" # decay mode for generating learning rate

@@ -28,8 +28,8 @@ config = ed({
     "image_height": 224,
     "image_width": 224,
     "save_checkpoint": True,
-    "save_checkpoint_steps": 500,
-    "keep_checkpoint_max": 40,
+    "save_checkpoint_epochs": 1,
+    "keep_checkpoint_max": 10,
     "save_checkpoint_path": "./",
     "warmup_epochs": 0,
     "lr_decay_mode": "cosine",
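
The README and config.py hunks are the same change: the checkpoint interval is now expressed in epochs rather than a fixed step count, and fewer checkpoints are retained (10 instead of 40). A minimal sketch of the resulting config block, assuming the easydict style this repo uses (surrounding keys omitted):

from easydict import EasyDict as ed

# Sketch of the checkpoint-related fields after this commit.
config = ed({
    "save_checkpoint": True,       # whether to save checkpoints at all
    "save_checkpoint_epochs": 1,   # save every N epochs (replaces "save_checkpoint_steps": 500)
    "keep_checkpoint_max": 10,     # keep only the 10 most recent checkpoints (was 40)
    "save_checkpoint_path": "./",  # directory for checkpoint files
})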

@@ -54,7 +54,7 @@ if __name__ == '__main__':
     if not args_opt.do_eval and args_opt.run_distribute:
         context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
                                           mirror_mean=True, parameter_broadcast=True)
-        auto_parallel_context().set_all_reduce_fusion_split_indices([140])
+        auto_parallel_context().set_all_reduce_fusion_split_indices([180, 313])
         init()

     epoch_size = config.epoch_size

@@ -59,7 +59,7 @@ if __name__ == '__main__':
     if not args_opt.do_eval and args_opt.run_distribute:
         context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
                                           mirror_mean=True, parameter_broadcast=True)
-        auto_parallel_context().set_all_reduce_fusion_split_indices([140])
+        auto_parallel_context().set_all_reduce_fusion_split_indices([180, 313])
         init()

     epoch_size = config.epoch_size
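
The two identical hunks above (apparently the same setup repeated in two training scripts) change how gradients are bucketed for fused AllReduce in data-parallel mode: split indices [180, 313] fuse the gradient list into three groups (parameters 0-179, 180-312, and 313 onward) instead of the two groups produced by the single split point 140. A standalone sketch of this setup; the import paths and the device count are assumptions based on the MindSpore version of this era, not taken from the diff:

from mindspore import context
from mindspore.communication.management import init
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.train.model import ParallelMode

# Assumed stand-in for the parsed CLI value args_opt.device_num.
device_num = 8

# Data-parallel setup as in the hunks above: average gradients across
# devices (mirror_mean) and broadcast initial parameters from device 0.
context.set_auto_parallel_context(device_num=device_num,
                                  parallel_mode=ParallelMode.DATA_PARALLEL,
                                  mirror_mean=True, parameter_broadcast=True)

# Fuse gradient AllReduce into three buckets, split after parameter
# indices 180 and 313 (the old code used a single split at 140).
auto_parallel_context().set_all_reduce_fusion_split_indices([180, 313])
init()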

@@ -91,7 +91,7 @@ if __name__ == '__main__':
     loss_cb = LossMonitor()
     cb = [time_cb, loss_cb]
     if config.save_checkpoint:
-        config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_steps,
+        config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs*step_size,
                                      keep_checkpoint_max=config.keep_checkpoint_max)
         ckpt_cb = ModelCheckpoint(prefix="resnet", directory=config.save_checkpoint_path, config=config_ck)
         cb += [ckpt_cb]
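
This last hunk is what makes the epoch-based setting take effect: the fixed interval config.save_checkpoint_steps is replaced by config.save_checkpoint_epochs * step_size, where step_size is the number of batches per epoch. With save_checkpoint_epochs = 1 and keep_checkpoint_max = 10, a checkpoint is written once per epoch and only the last ten are kept. A minimal sketch of the callback wiring, reusing the config sketch above; the step_size value is a placeholder (in the real script it comes from the dataset, e.g. dataset.get_dataset_size()):

from mindspore.train.callback import (CheckpointConfig, LossMonitor,
                                      ModelCheckpoint, TimeMonitor)

step_size = 5004  # placeholder for batches per epoch, e.g. dataset.get_dataset_size()

time_cb = TimeMonitor(data_size=step_size)
loss_cb = LossMonitor()
cb = [time_cb, loss_cb]
if config.save_checkpoint:
    # Convert the epoch interval to steps, since CheckpointConfig counts steps;
    # this saves exactly once every save_checkpoint_epochs epochs.
    config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size,
                                 keep_checkpoint_max=config.keep_checkpoint_max)
    ckpt_cb = ModelCheckpoint(prefix="resnet", directory=config.save_checkpoint_path, config=config_ck)
    cb += [ckpt_cb]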