!1717 fix bug introduced by gpu support
Merge pull request !1717 from gengdongjie/master
This commit is contained in:
commit
9eee55969a
|
@ -29,7 +29,7 @@ config = ed({
|
||||||
"image_height": 224,
|
"image_height": 224,
|
||||||
"image_width": 224,
|
"image_width": 224,
|
||||||
"save_checkpoint": True,
|
"save_checkpoint": True,
|
||||||
"save_checkpoint_epochs": 1,
|
"save_checkpoint_epochs": 5,
|
||||||
"keep_checkpoint_max": 10,
|
"keep_checkpoint_max": 10,
|
||||||
"save_checkpoint_path": "./",
|
"save_checkpoint_path": "./",
|
||||||
"warmup_epochs": 0,
|
"warmup_epochs": 0,
|
||||||
|
|
|
@ -28,7 +28,7 @@ config = ed({
|
||||||
"image_height": 224,
|
"image_height": 224,
|
||||||
"image_width": 224,
|
"image_width": 224,
|
||||||
"save_checkpoint": True,
|
"save_checkpoint": True,
|
||||||
"save_checkpoint_steps": 1950,
|
"save_checkpoint_epochs": 5,
|
||||||
"keep_checkpoint_max": 10,
|
"keep_checkpoint_max": 10,
|
||||||
"save_checkpoint_path": "./",
|
"save_checkpoint_path": "./",
|
||||||
"warmup_epochs": 5,
|
"warmup_epochs": 5,
|
||||||
|
|
|
@ -43,6 +43,8 @@ args_opt = parser.parse_args()
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
target = args_opt.device_target
|
target = args_opt.device_target
|
||||||
|
ckpt_save_dir = config.save_checkpoint_path
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False)
|
||||||
if not args_opt.do_eval and args_opt.run_distribute:
|
if not args_opt.do_eval and args_opt.run_distribute:
|
||||||
if target == "Ascend":
|
if target == "Ascend":
|
||||||
device_id = int(os.getenv('DEVICE_ID'))
|
device_id = int(os.getenv('DEVICE_ID'))
|
||||||
|
@ -80,13 +82,13 @@ if __name__ == '__main__':
|
||||||
else:
|
else:
|
||||||
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
|
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
|
||||||
model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'},
|
model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'},
|
||||||
amp_level="O2", keep_batchnorm_fp32=True)
|
amp_level="O2", keep_batchnorm_fp32=False)
|
||||||
|
|
||||||
time_cb = TimeMonitor(data_size=step_size)
|
time_cb = TimeMonitor(data_size=step_size)
|
||||||
loss_cb = LossMonitor()
|
loss_cb = LossMonitor()
|
||||||
cb = [time_cb, loss_cb]
|
cb = [time_cb, loss_cb]
|
||||||
if config.save_checkpoint:
|
if config.save_checkpoint:
|
||||||
config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_steps,
|
config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs*step_size,
|
||||||
keep_checkpoint_max=config.keep_checkpoint_max)
|
keep_checkpoint_max=config.keep_checkpoint_max)
|
||||||
ckpt_cb = ModelCheckpoint(prefix="resnet", directory=ckpt_save_dir, config=config_ck)
|
ckpt_cb = ModelCheckpoint(prefix="resnet", directory=ckpt_save_dir, config=config_ck)
|
||||||
cb += [ckpt_cb]
|
cb += [ckpt_cb]
|
||||||
|
|
|
@ -29,7 +29,7 @@ config = ed({
|
||||||
"image_height": 224,
|
"image_height": 224,
|
||||||
"image_width": 224,
|
"image_width": 224,
|
||||||
"save_checkpoint": True,
|
"save_checkpoint": True,
|
||||||
"save_checkpoint_epochs": 1,
|
"save_checkpoint_epochs": 5,
|
||||||
"keep_checkpoint_max": 10,
|
"keep_checkpoint_max": 10,
|
||||||
"save_checkpoint_path": "./",
|
"save_checkpoint_path": "./",
|
||||||
"warmup_epochs": 0,
|
"warmup_epochs": 0,
|
||||||
|
|
|
@ -46,6 +46,8 @@ args_opt = parser.parse_args()
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
target = args_opt.device_target
|
target = args_opt.device_target
|
||||||
|
ckpt_save_dir = config.save_checkpoint_path
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False)
|
||||||
if not args_opt.do_eval and args_opt.run_distribute:
|
if not args_opt.do_eval and args_opt.run_distribute:
|
||||||
if target == "Ascend":
|
if target == "Ascend":
|
||||||
device_id = int(os.getenv('DEVICE_ID'))
|
device_id = int(os.getenv('DEVICE_ID'))
|
||||||
|
|
Loading…
Reference in New Issue