forked from mindspore-Ecosystem/mindspore
!17144 Enable acceleration by Graph Kernel for VGG16 only GPU.
From: @xixixian Reviewed-by: @gaoxiong1,@ckey_dou Signed-off-by: @ckey_dou
This commit is contained in:
commit
43463b9ff1
|
@ -119,8 +119,9 @@ def merge_args(args, cloud_args):
|
||||||
def test(cloud_args=None):
|
def test(cloud_args=None):
|
||||||
"""test"""
|
"""test"""
|
||||||
args = parse_args(cloud_args)
|
args = parse_args(cloud_args)
|
||||||
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
|
_enable_graph_kernel = args.device_target == "GPU"
|
||||||
device_target=args.device_target, save_graphs=False)
|
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=_enable_graph_kernel,
|
||||||
|
enable_auto_mixed_precision=True, device_target=args.device_target, save_graphs=False)
|
||||||
if os.getenv('DEVICE_ID', "not_set").isdigit() and args.device_target == "Ascend":
|
if os.getenv('DEVICE_ID', "not_set").isdigit() and args.device_target == "Ascend":
|
||||||
context.set_context(device_id=int(os.getenv('DEVICE_ID')))
|
context.set_context(device_id=int(os.getenv('DEVICE_ID')))
|
||||||
|
|
||||||
|
|
|
@ -126,7 +126,9 @@ def merge_args(args_opt, cloud_args):
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
_enable_graph_kernel = args.device_target == "GPU"
|
||||||
|
context.set_context(mode=context.GRAPH_MODE,
|
||||||
|
enable_graph_kernel=_enable_graph_kernel, device_target=args.device_target)
|
||||||
device_num = int(os.environ.get("DEVICE_NUM", 1))
|
device_num = int(os.environ.get("DEVICE_NUM", 1))
|
||||||
if args.is_distributed:
|
if args.is_distributed:
|
||||||
if args.device_target == "Ascend":
|
if args.device_target == "Ascend":
|
||||||
|
|
Loading…
Reference in New Issue