forked from mindspore-Ecosystem/mindspore
move set_context before init
This commit is contained in:
parent
29db53c2ba
commit
db124ce8fc
|
@ -112,6 +112,10 @@ args.lr_epochs = list(map(int, args.lr_epochs.split(',')))
|
|||
args.data_root = os.path.join(args.data_dir, 'train2017')
|
||||
args.annFile = os.path.join(args.data_dir, 'annotations/instances_train2017.json')
|
||||
|
||||
device_id = int(os.getenv('DEVICE_ID', '0'))
|
||||
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
|
||||
device_target=args.device_target, save_graphs=False, device_id=device_id)
|
||||
|
||||
# init distributed
|
||||
if args.is_distributed:
|
||||
if args.device_target == "Ascend":
|
||||
|
@ -154,9 +158,6 @@ class BuildTrainNetwork(nn.Cell):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
device_id = int(os.getenv('DEVICE_ID', '0'))
|
||||
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
|
||||
device_target=args.device_target, save_graphs=False, device_id=device_id)
|
||||
if args.need_profiler:
|
||||
profiler = Profiler(output_path=args.outputs_dir, is_detail=True, is_show_op_path=True)
|
||||
|
||||
|
|
Loading…
Reference in New Issue