!18142 Enable graph_kernel while train NASNet on GPU
Merge pull request !18142 from hujiahui8/graph_kernel
This commit is contained in:
commit
9c07a9c116
|
@ -48,7 +48,7 @@ if __name__ == '__main__':
|
||||||
if args_opt.platform != "GPU":
|
if args_opt.platform != "GPU":
|
||||||
raise ValueError("Only supported GPU training.")
|
raise ValueError("Only supported GPU training.")
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform, save_graphs=False)
|
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform, enable_graph_kernel=True)
|
||||||
if os.getenv('DEVICE_ID', "not_set").isdigit():
|
if os.getenv('DEVICE_ID', "not_set").isdigit():
|
||||||
context.set_context(device_id=int(os.getenv('DEVICE_ID')))
|
context.set_context(device_id=int(os.getenv('DEVICE_ID')))
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue